From c038b711911bf7d42f25307b874b7efa04619aec Mon Sep 17 00:00:00 2001 From: Fiifi Botchway Date: Thu, 25 Jun 2026 12:20:10 +0100 Subject: [PATCH 1/4] feat(ontology): add PGE Evaluator stage to owl-generator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Turn owl-generation into a real Planner→Generator→Evaluator loop. After the pitfall-tool fix loop settles, a deterministic Stage-1 evaluator scores the ontology against source metadata and feeds concrete retry-hints back to the generator on Tier-1 structural defects (orphan classes, dangling domain/range, naming violations, duplicate classes), bounded by MAX_OWL_EVAL_ROUNDS. - engine.py: _evaluate_ontology_stage() + loop wiring (fails open; never discards a usable ontology), MAX_OUTPUT_TOKENS=16000, exhaustive ATTRIBUTE COVERAGE prompt + get_table_detail workflow step. - New agents/pge_eval slice: normalize.py + ontology_metrics.evaluate_ontology (gold-free, intrinsic; minimal package root to avoid coupling). - Tests: ontology_metrics + owl_evaluator_stage (39 targeted, 565 unit green). 
Co-authored-by: Isaac --- changelogs/v0.5.2/FiifiB_2026-06-25.log | 58 ++++ src/agents/agent_owl_generator/engine.py | 173 ++++++++++- src/agents/pge_eval/__init__.py | 16 ++ src/agents/pge_eval/normalize.py | 262 +++++++++++++++++ src/agents/pge_eval/ontology_metrics.py | 270 ++++++++++++++++++ tests/units/pge_eval/__init__.py | 0 tests/units/pge_eval/_fixtures.py | 178 ++++++++++++ tests/units/pge_eval/test_ontology_metrics.py | 83 ++++++ .../pge_eval/test_owl_evaluator_stage.py | 69 +++++ 9 files changed, 1100 insertions(+), 9 deletions(-) create mode 100644 changelogs/v0.5.2/FiifiB_2026-06-25.log create mode 100644 src/agents/pge_eval/__init__.py create mode 100644 src/agents/pge_eval/normalize.py create mode 100644 src/agents/pge_eval/ontology_metrics.py create mode 100644 tests/units/pge_eval/__init__.py create mode 100644 tests/units/pge_eval/_fixtures.py create mode 100644 tests/units/pge_eval/test_ontology_metrics.py create mode 100644 tests/units/pge_eval/test_owl_evaluator_stage.py diff --git a/changelogs/v0.5.2/FiifiB_2026-06-25.log b/changelogs/v0.5.2/FiifiB_2026-06-25.log new file mode 100644 index 00000000..8a23f7f6 --- /dev/null +++ b/changelogs/v0.5.2/FiifiB_2026-06-25.log @@ -0,0 +1,58 @@ +# 2026-06-25 — feat(ontology): PGE Evaluator stage for owl-generator + +## Context + +The owl-generator agent had a single-shot generation + a pitfall-tool fix loop, +but no deterministic Evaluator stage — so structural defects (orphan classes, +dangling domain/range, naming violations, duplicate classes) could survive into +the delivered ontology. This change turns owl-generation into a real +Planner→Generator→Evaluator (PGE) loop: after the pitfall loop settles, a +deterministic Stage-1 evaluator scores the ontology against the source metadata +and feeds concrete retry-hints back to the generator, bounded by a hard cap. 
+ +The Evaluator reuses a small, usecase-agnostic ontology-metrics module +(`agents.pge_eval.ontology_metrics`) — gold-free, computed purely from the +generated ontology + source schema. Only the ontology slice of the metrics +package is introduced here; the full scorecard/CLI lands separately. + +## Changes + +1. `src/agents/agent_owl_generator/engine.py` + - Add `MAX_OWL_EVAL_ROUNDS` (bounded Evaluator retry cap) and + `_evaluate_ontology_stage()` — parses the Turtle, runs the deterministic + Tier-1 ontology checks, and returns a retry-hint string on hard defects + (orphan / dangling domain-range / naming / duplicate). Fails open: any + parse/dep error returns `None` so a check failure never blocks delivery. + - Wire the Evaluator into the agent loop after the pitfall loop; only retry + when an iteration remains, so a usable ontology is never discarded by + exhausting `MAX_ITERATIONS`. + - Raise `max_tokens` to `MAX_OUTPUT_TOKENS = 16000` so exhaustive attribute + coverage isn't silently truncated past the old 4096 ceiling. + - Strengthen the system prompt: `# ATTRIBUTE COVERAGE` section + a + `get_table_detail`-per-table workflow step driving exhaustive (not curated) + datatype-property coverage. +2. `src/agents/pge_eval/__init__.py` — new package (minimal root; importers + depend on the concrete submodule to avoid coupling to later modules). +3. `src/agents/pge_eval/normalize.py` — shared name/metadata/ontology + normalization primitives (stdlib-only). +4. `src/agents/pge_eval/ontology_metrics.py` — `evaluate_ontology()`: + deterministic Stage-1 checks + footprint coverage, no stored reference. +5. Tests: `tests/units/pge_eval/{__init__,_fixtures}.py`, + `test_ontology_metrics.py`, `test_owl_evaluator_stage.py`. 
+ +## Modified / added files + +- M src/agents/agent_owl_generator/engine.py +- A src/agents/pge_eval/__init__.py +- A src/agents/pge_eval/normalize.py +- A src/agents/pge_eval/ontology_metrics.py +- A tests/units/pge_eval/__init__.py +- A tests/units/pge_eval/_fixtures.py +- A tests/units/pge_eval/test_ontology_metrics.py +- A tests/units/pge_eval/test_owl_evaluator_stage.py + +## Tests + +`uv run pytest tests/units/pge_eval/test_ontology_metrics.py +tests/units/pge_eval/test_owl_evaluator_stage.py +tests/units/ontology/test_owl_generator.py -q` → **39 passed**. diff --git a/src/agents/agent_owl_generator/engine.py b/src/agents/agent_owl_generator/engine.py index 9d755d48..c16f4ad1 100644 --- a/src/agents/agent_owl_generator/engine.py +++ b/src/agents/agent_owl_generator/engine.py @@ -37,6 +37,18 @@ MAX_ITERATIONS = 10 LLM_TIMEOUT = 180 +# Exhaustive per-class datatype-property coverage (see # ATTRIBUTE COVERAGE in +# the system prompt) makes the Turtle output large — a large domain ontology +# with dozens of classes and 50+ datatype properties runs well past the old 4096 +# ceiling, which silently truncated the final statement and broke parsing. +# Claude Opus supports large completions; 16k tokens fits an exhaustive +# domain ontology with headroom. +MAX_OUTPUT_TOKENS = 16000 + +# Bounded PGE retry cap for the Evaluator stage (§3.5): how many times the +# deterministic Stage-1 ontology checks may feed retry_hints back into +# generation before owl delivery proceeds regardless. +MAX_OWL_EVAL_ROUNDS = 2 _TRACE_NAME = "owl_generator" @@ -95,9 +107,12 @@ def _load_pitfall_rules() -> str: # WORKFLOW 1. Call get_metadata to understand the database schema. -2. Call list_documents to discover available documents. -3. Read relevant documents with read_document. -4. Output ONLY the final Turtle ontology as plain text (starting with @prefix). +2. 
Call get_table_detail on EVERY table you intend to map a class to — get_metadata + truncates wide tables at 80 columns, and you must see the FULL column list to give + each class exhaustive attribute coverage (see # ATTRIBUTE COVERAGE). +3. Call list_documents to discover available documents. +4. Read relevant documents with read_document. +5. Output ONLY the final Turtle ontology as plain text (starting with @prefix). # NAMING RULES (CRITICAL – NO EXCEPTIONS) • Classes: PascalCase (Customer, SalesOrder) @@ -127,6 +142,34 @@ def _load_pitfall_rules() -> str: • For EVERY DatatypeProperty you MUST declare rdfs:domain on the property itself (do not rely on owl:Restriction alone — the platform reads attributes from rdfs:domain) +# ATTRIBUTE COVERAGE (CRITICAL — exhaustive, NOT curated) +The downstream mapping pipeline can only bind a SQL column to a class when that +column has a matching owl:DatatypeProperty with rdfs:domain on the class. A class +with few datatype properties produces an ID+Label-only entity that is USELESS for +analytics. So model attributes EXHAUSTIVELY, not minimally: +• For EVERY class, emit a DatatypeProperty for EVERY meaningful source column that + describes an instance of that class — across ALL tables that realise the class. + A single class is often realised by several source tables (e.g. one per source + system, region, or tenant) that each hold the same real-world entity in a local + schema; UNION their columns mentally and cover the full set. Use get_table_detail + on each covering table to see every column. +• "Meaningful" = a genuine attribute of the entity: dates, measurements, codes, + scores, names, statuses, flags, free-text notes. EXCLUDE ONLY: surrogate/auto- + increment row keys with no analytical value, audit columns (created_at, updated_by, + etl_*, _ingest_*), and the foreign-key columns that ObjectProperty relationships + already carry. +• When two sources expose the SAME attribute under different column names + (e.g. 
total_amount vs TOTAL_AMT; status vs STATUS_CODE), emit ONE datatype + property — do NOT emit a per-source duplicate. The mapping layer reconciles the + source columns. +• Name datatype properties in lowerCamelCase derived from business meaning + (order_date → orderDate, TOTAL_AMT → totalAmount). + Use ONLY [a-z][A-Za-z0-9]* — never underscores, hyphens, or backslash escapes. +• The "at least 2 datatype properties" floor in the guidelines is a MINIMUM, not a + target. Rich, real-world entities (a transaction, an encounter, an event, a core + business object) typically warrant 6–11 datatype properties. Aim for full column + coverage, not a tidy subset. + # RELATIONSHIP RULES • NEVER create bidirectional relationships. • Between any two classes A and B create at most ONE ObjectProperty. @@ -160,10 +203,12 @@ def _load_pitfall_rules() -> str: ## 2. Class and property design rules For each **class** you create:[1][2][3][4] -1. Provide: - - A short, clear natural-language definition (1–2 sentences). - - At least 1 object property (unless the class is explicitly abstract). - - At least 2 datatype properties, when meaningful in the domain. +1. Provide: + - A short, clear natural-language definition (1–2 sentences). + - At least 1 object property (unless the class is explicitly abstract). + - Datatype properties covering EVERY meaningful source column for the class + (see "# ATTRIBUTE COVERAGE" in the system prompt — exhaustive, not curated; + 2 is a floor, full column coverage is the goal). 2. Naming conventions: - Classes: UpperCamelCase (e.g., `CustomerOrder`). - Object properties: lowerCamelCase verbs or verb-like phrases (e.g., `placesOrder`). @@ -241,6 +286,80 @@ def _parse_pitfall_tool_result(tool_result_json: str) -> Optional[Dict]: return None +# Stage-1 absolute (Tier-1) ontology defects that the Evaluator forces a +# retry on. 
Coverage ratios are computed and logged but are advisory at the +# generation stage (they are Tier-2 in the scorecard), so they do not by +# themselves trigger a regeneration — only hard structural defects do. +_EVAL_ABSOLUTE_CHECKS = ( + "orphan_class_count", + "dangling_domain_range_count", + "naming_violation_count", + "duplicate_class_count", +) + + +def _evaluate_ontology_stage( + turtle_text: str, metadata: dict, iteration: int +) -> Optional[str]: + """Run the Stage-1 deterministic ontology checks (§3.2) on *turtle_text*. + + Parses the Turtle into the registry shape, runs the shared intrinsic + checks, and returns a concrete ``retry_hint`` feedback string when any + Tier-1 absolute defect (orphan class, dangling domain/range, naming + violation, duplicate class) is present — turning owl-gen into a real + PGE loop. Returns ``None`` when the ontology is structurally clean. + + Fails open: any parse/dep error returns ``None`` so a check failure + never blocks OWL delivery (mirrors the pitfall-tool check). + """ + try: + from back.core.w3c.owl.OntologyParser import OntologyParser + from back.objects.ontology.Ontology import Ontology + from agents.pge_eval.ontology_metrics import evaluate_ontology + + # The model sometimes prepends a prose sentence or wraps the Turtle in + # a markdown fence; strip that the same way the downstream registry + # does, so the Evaluator parses real output instead of skipping. 
+ turtle_text = Ontology.clean_owl_output(turtle_text) + parser = OntologyParser(turtle_text) + ontology = { + "classes": parser.get_classes(), + "properties": parser.get_properties(), + } + metrics, issues, _footprint = evaluate_ontology(ontology, metadata or {}) + logger.info( + "Iteration %d: ontology evaluator — metrics=%s", + iteration, + metrics, + ) + + absolute_issues = [ + i for i in issues if i.get("check") in _EVAL_ABSOLUTE_CHECKS + ] + if not absolute_issues: + logger.info( + "Iteration %d: ontology evaluator — no Tier-1 defects", iteration + ) + return None + + lines = [ + "The ontology you produced has structural defects. Fix ALL of them " + "and output ONLY the corrected Turtle (no markdown, no comments, " + "starting with @prefix declarations):\n" + ] + # Cap feedback to keep the prompt bounded. + for issue in absolute_issues[:12]: + lines.append(f" • {issue['hint']}") + return "\n".join(lines) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Iteration %d: ontology evaluator skipped due to error: %s", + iteration, + exc, + ) + return None + + def _build_user_prompt( guidelines: str, options: dict, @@ -443,6 +562,7 @@ def notify(msg: str): # ------------------------------------------------------------------ tools_supported = True _owl_fix_rounds = 0 # pitfall-fix rounds consumed so far + _owl_eval_rounds = 0 # Evaluator (Stage-1 PGE) retry rounds consumed for iteration in range(MAX_ITERATIONS): logger.info( @@ -477,7 +597,7 @@ def notify(msg: str): endpoint_name, messages, tools=send_tools, - max_tokens=4096, + max_tokens=MAX_OUTPUT_TOKENS, temperature=0.1, timeout=LLM_TIMEOUT, trace_name=_TRACE_NAME, @@ -509,7 +629,7 @@ def notify(msg: str): endpoint_name, messages, tools=None, - max_tokens=4096, + max_tokens=MAX_OUTPUT_TOKENS, temperature=0.1, timeout=LLM_TIMEOUT, trace_name=_TRACE_NAME, @@ -749,6 +869,41 @@ def notify(msg: str): _owl_fix_rounds, max_fix_rounds, ) + # -------------------------------------------------------------- + # 
Evaluator stage (PGE loop) — after the pitfall-tool loop is + # clean/maxed, run the Stage-1 deterministic ontology checks (§3.2). + # On a Tier-1 structural defect, feed concrete retry_hints back to + # the generator, bounded by MAX_OWL_EVAL_ROUNDS. Only retry when + # there's another iteration left, so a usable ontology is never + # discarded by exhausting MAX_ITERATIONS. + # -------------------------------------------------------------- + eval_feedback = _evaluate_ontology_stage(content, ctx.metadata, iteration + 1) + if ( + eval_feedback + and _owl_eval_rounds < MAX_OWL_EVAL_ROUNDS + and iteration < MAX_ITERATIONS - 1 + ): + _owl_eval_rounds += 1 + notify( + f"Ontology defects found — eval round " + f"{_owl_eval_rounds}/{MAX_OWL_EVAL_ROUNDS}…" + ) + result.steps.append( + AgentStep( + step_type="evaluator", + content=eval_feedback[:200], + duration_ms=0, + ) + ) + messages.append({"role": "assistant", "content": content}) + messages.append({"role": "user", "content": eval_feedback}) + logger.info( + "Iteration %d: ontology evaluator found defects — eval round %d", + iteration + 1, + _owl_eval_rounds, + ) + continue # next iteration will produce corrected OWL + # ── Accept this text as the final OWL ──────────────────────────── result.success = True result.owl_content = content diff --git a/src/agents/pge_eval/__init__.py b/src/agents/pge_eval/__init__.py new file mode 100644 index 00000000..5a2089d9 --- /dev/null +++ b/src/agents/pge_eval/__init__.py @@ -0,0 +1,16 @@ +"""OntoBricks PGE intrinsic-evaluation primitives. + +This package holds usecase-agnostic, gold-free structural metrics for the PGE +pipeline. 
This PR introduces only the **ontology** slice consumed by the +owl-generator Evaluator stage: + +* :func:`agents.pge_eval.ontology_metrics.evaluate_ontology` — Stage-1 + deterministic ontology checks (orphan classes, dangling domain/range, + naming, duplicates, footprint coverage), computed purely from the generated + ontology + source metadata (no stored reference answer). + +The full scorecard (mapping metrics, gate tiers, baseline regression, LLM +judge, CLI) lands in a separate change. Importers should depend on the +concrete submodule (``agents.pge_eval.ontology_metrics``) rather than this +package root to avoid coupling to modules introduced later. +""" diff --git a/src/agents/pge_eval/normalize.py b/src/agents/pge_eval/normalize.py new file mode 100644 index 00000000..3e20baab --- /dev/null +++ b/src/agents/pge_eval/normalize.py @@ -0,0 +1,262 @@ +"""Shape normalisation + footprint helpers for the PGE intrinsic evaluator. + +Everything in this module is pure Python — no LLM, no DB, no domain +knowledge. It exists so the rest of the scorer can reason over one stable +in-memory shape regardless of whether the caller handed it the *agent* +ontology shape (``{entities, relationships}``), the *registry* ontology +shape (``{classes, properties}``), or raw source metadata. + +Design constraints (see docs/plans/2026-06-10-goal-loop-and-pge-eval-design.md): + +* **Usecase-agnostic.** No table name, identifier, or count from any + particular domain is encoded here. The only constants are generic + audit/surrogate column heuristics that hold for any relational source. +* **Deterministic.** Pure functions of their inputs; no randomness, no + wall-clock, no network. 
+""" + +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional, Set + + +# ===================================================== +# Name normalisation +# ===================================================== + + +def normalize_name(name: Optional[str]) -> str: + """Collapse a column / property / class name to a comparison key. + + Lower-cases and strips every non-alphanumeric character so that + ``first_name``, ``firstName`` and ``FirstName`` all collapse to + ``firstname``. This is the footprint-matching key used to decide + whether a source column "became" a data property without consulting + the mapping (Stage-1 is mapping-independent — see D2/D3). + """ + if not name: + return "" + return re.sub(r"[^a-z0-9]", "", str(name).lower()) + + +def local_name(uri_or_name: Optional[str]) -> str: + """Return the local name of a URI/CURIE, or the value unchanged. + + ``http://x/Customer`` -> ``Customer``; ``ex:Customer`` -> ``Customer``; + ``Customer`` -> ``Customer``. + """ + if not uri_or_name: + return "" + s = str(uri_or_name) + for sep in ("#", "/"): + if sep in s: + s = s.rsplit(sep, 1)[-1] + if ":" in s and not s.startswith("http"): + s = s.rsplit(":", 1)[-1] + return s + + +# ===================================================== +# Audit / surrogate column heuristics (generic, not domain-specific) +# ===================================================== + +# Audit tokens that mark a column as non-analytical bookkeeping. These are +# generic ETL/CDC conventions, not tied to any domain. +_AUDIT_TOKENS = ( + "createdat", + "updatedat", + "createdon", + "updatedon", + "createdby", + "updatedby", + "modifiedat", + "modifiedby", + "deletedat", + "ingestedat", + "loadedat", + "loadts", + "etltimestamp", + "dwcreated", + "dwupdated", +) +_AUDIT_PREFIXES = ("etl", "ingest", "_ingest", "dw") +# Exact surrogate row-key names + suffixes for warehouse surrogate keys. 
+_SURROGATE_EXACT = ("id", "rowid", "rownum", "rownumber")
+_SURROGATE_SUFFIXES = ("sk", "surrogatekey")
+
+
+def is_surrogate_or_audit(column_name: str) -> bool:
+    """Heuristic: True when *column_name* is a surrogate row key or audit
+    column with no analytical value.
+
+    The OWL generator is instructed to drop exactly these, so they are
+    excluded from coverage denominators (D3). Intentionally conservative:
+    it does NOT drop every ``*_id`` column (foreign keys can be meaningful),
+    only obvious surrogate keys and audit bookkeeping.
+
+    Suffixes and prefixes are matched against whole name tokens (split on
+    separators, digit runs and camelCase/acronym boundaries) rather than raw
+    substrings, so ``customer_sk`` and ``dw_batch`` are dropped while
+    ordinary words such as ``risk``, ``task`` or ``dwelling_type`` are kept.
+
+    Args:
+        column_name: Source column name in any casing / separator style.
+
+    Returns:
+        True when the name is empty after normalisation, is an exact
+        surrogate key name, ends in a surrogate suffix token, contains an
+        audit token, or starts with an audit prefix token; False otherwise.
+    """
+    norm = normalize_name(column_name)
+    if not norm:
+        return True
+    if norm in _SURROGATE_EXACT:
+        return True
+    # Audit tokens are long, specific compounds, so a substring match on the
+    # separator-free key is safe here.
+    if any(tok in norm for tok in _AUDIT_TOKENS):
+        return True
+    # Short affixes ("sk", "dw", "etl") are only meaningful at a word
+    # boundary; tokenise the original spelling so they cannot match the tail
+    # or head of an unrelated word.
+    tokens = [
+        t.lower()
+        for t in re.findall(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|[0-9]+", str(column_name))
+    ]
+    # Every run of trailing tokens, joined: lets "surrogate_key" match the
+    # "surrogatekey" suffix as well as "customer_sk" match "sk".
+    trailing = {"".join(tokens[i:]) for i in range(len(tokens))}
+    if not trailing.isdisjoint(_SURROGATE_SUFFIXES):
+        return True
+    # Every run of leading tokens, joined; prefixes are normalised so the
+    # "_ingest" spelling compares equal to the "ingest" token.
+    leading = {"".join(tokens[:i]) for i in range(1, len(tokens) + 1)}
+    return any(normalize_name(p) in leading for p in _AUDIT_PREFIXES)
+
+
+# =====================================================
+# Ontology normalisation
+# =====================================================
+
+
+def _attr_names(raw_attrs: Any) -> List[str]:
+    """Normalise an attribute container to a flat list of name strings.
+
+    Accepts the agent shape (list of str or ``{name|uri|label}`` dicts) and
+    the registry shape (list of ``{name|localName}`` dicts).
+    """
+    out: List[str] = []
+    for a in raw_attrs or []:
+        if isinstance(a, str):
+            out.append(a)
+        elif isinstance(a, dict):
+            name = a.get("name") or a.get("localName") or a.get("uri") or a.get("label")
+            if name:
+                out.append(local_name(name))
+    return out
+
+
+class NormalizedOntology:
+    """A flat, shape-agnostic view of a generated ontology.
+
+    Attributes:
+        classes: list of ``{"name", "uri", "data_properties": [str]}``.
+        object_properties: list of ``{"name", "uri", "domain", "range"}``
+            where domain/range are the raw refs as authored (URI or local).
+ """ + + def __init__(self, classes: List[dict], object_properties: List[dict]): + self.classes = classes + self.object_properties = object_properties + + # --- derived sets, computed lazily but cheaply ------------------ + + @property + def class_resolution_set(self) -> Set[str]: + """Every token a domain/range ref could legitimately resolve to.""" + out: Set[str] = set() + for c in self.classes: + if c.get("uri"): + out.add(c["uri"]) + out.add(local_name(c["uri"])) + if c.get("name"): + out.add(c["name"]) + out.add(local_name(c["name"])) + return out + + @property + def all_data_property_keys(self) -> Set[str]: + """Normalised keys of every data property across every class.""" + keys: Set[str] = set() + for c in self.classes: + for dp in c.get("data_properties", []): + k = normalize_name(local_name(dp)) + if k: + keys.add(k) + return keys + + @property + def class_name_keys(self) -> Set[str]: + keys: Set[str] = set() + for c in self.classes: + k = normalize_name(local_name(c.get("name") or c.get("uri"))) + if k: + keys.add(k) + return keys + + +def normalize_ontology(ontology: dict) -> NormalizedOntology: + """Normalise either the agent shape or the registry shape. 
+ + * Agent shape: ``{"entities": [...], "relationships": [...]}`` + * Registry shape: ``{"classes": [...], "properties": [...]}`` + """ + ontology = ontology or {} + classes: List[dict] = [] + object_props: List[dict] = [] + + if "entities" in ontology or "relationships" in ontology: + for e in ontology.get("entities", []) or []: + classes.append( + { + "name": e.get("name") or local_name(e.get("uri")), + "uri": e.get("uri", ""), + "data_properties": _attr_names(e.get("attributes")), + } + ) + for r in ontology.get("relationships", []) or []: + object_props.append( + { + "name": r.get("name") or local_name(r.get("uri")), + "uri": r.get("uri", ""), + "domain": r.get("domain", ""), + "range": r.get("range", ""), + } + ) + else: + for c in ontology.get("classes", []) or []: + classes.append( + { + "name": c.get("name") or local_name(c.get("uri")), + "uri": c.get("uri", ""), + "data_properties": _attr_names(c.get("dataProperties")), + } + ) + for p in ontology.get("properties", []) or []: + if p.get("type") and p.get("type") != "ObjectProperty": + continue + object_props.append( + { + "name": p.get("name") or local_name(p.get("uri")), + "uri": p.get("uri", ""), + "domain": p.get("domain", ""), + "range": p.get("range", ""), + } + ) + + return NormalizedOntology(classes=classes, object_properties=object_props) + + +# ===================================================== +# Source-metadata normalisation +# ===================================================== + + +def normalize_metadata(metadata: dict) -> List[dict]: + """Return ``[{"name", "columns": [str]}]`` from domain metadata. + + Accepts the ``{"tables": [{"name"|"full_name", "columns": [...]}]}`` + shape produced by the metadata tools. Column entries may be plain + strings or ``{"name": ...}`` dicts. 
+ """ + out: List[dict] = [] + for t in (metadata or {}).get("tables", []) or []: + cols: List[str] = [] + for c in t.get("columns", []) or []: + if isinstance(c, str): + cols.append(c) + elif isinstance(c, dict) and c.get("name"): + cols.append(c["name"]) + out.append( + { + "name": t.get("full_name") or t.get("name") or "", + "columns": cols, + } + ) + return out diff --git a/src/agents/pge_eval/ontology_metrics.py b/src/agents/pge_eval/ontology_metrics.py new file mode 100644 index 00000000..21ba2c16 --- /dev/null +++ b/src/agents/pge_eval/ontology_metrics.py @@ -0,0 +1,270 @@ +"""Stage-1 — ontology-generation quality (deterministic, no LLM). + +Computed purely from the generated ontology + source metadata. No mapping +dependency (D2) and no LLM for the deterministic part (§3.2). The same +checks back the new owl-generator Evaluator stage (§3.5): each issue carries +a concrete ``hint`` that becomes a generator retry_hint. + +All metrics are usecase-agnostic: nothing about any particular domain is +hard-coded here. +""" + +from __future__ import annotations + +import re +from typing import Any, Dict, List, Set, Tuple + +from agents.pge_eval.normalize import ( + NormalizedOntology, + is_surrogate_or_audit, + local_name, + normalize_metadata, + normalize_name, + normalize_ontology, +) + +# Naming conventions (mirror the OWL generator's NAMING RULES, domain-free). 
+_CLASS_RE = re.compile(r"^[A-Z][A-Za-z0-9]*$")
+_PROPERTY_RE = re.compile(r"^[a-z][A-Za-z0-9]*$")
+
+
+def _issue(check: str, expected: str, observed: str, hint: str) -> Dict[str, str]:
+    """Build one issue record with the ``check/expected/observed/hint`` keys."""
+    return {"check": check, "expected": expected, "observed": observed, "hint": hint}
+
+
+# =====================================================
+# Footprint computation (shared with pipeline.coverage_loss)
+# =====================================================
+
+
+def _column_key(table_name: str, column_name: str) -> str:
+    """Return the ``<table>::<column>`` key built from both normalised names."""
+    return f"{normalize_name(table_name)}::{normalize_name(column_name)}"
+
+
+def compute_footprint(
+    ontology: NormalizedOntology, tables: List[dict]
+) -> Dict[str, Any]:
+    """Return the ontology footprint over the source metadata.
+
+    A *column* is covered when its normalised name matches some data
+    property's normalised name. A *table* is covered when its name matches
+    a class name OR ≥1 of its non-surrogate columns is covered (D3).
+
+    Surrogate/audit columns are excluded from the denominators.
+
+    Args:
+        ontology: Normalised ontology whose data-property and class-name
+            keys are matched against the source.
+        tables: ``[{"name", "columns": [str]}]`` entries as produced by
+            ``normalize_metadata``.
+
+    Returns:
+        Dict with ``total_tables`` / ``total_columns`` (ints) and
+        ``covered_tables`` (set of table names as given) /
+        ``covered_columns`` (set of ``<table>::<column>`` keys).
+    """
+    dp_keys = ontology.all_data_property_keys
+    class_keys = ontology.class_name_keys
+
+    total_columns = 0
+    covered_columns: Set[str] = set()
+    total_tables = len(tables)
+    covered_tables: Set[str] = set()
+
+    for t in tables:
+        tname = t["name"]
+        # The metadata name may be schema-qualified, and a qualified name can
+        # never equal a class key, so compare on the final segment only.
+        # NOTE(review): assumes qualifiers are dot-separated
+        # (``catalog.schema.table``) — confirm against the metadata tools.
+        tkey = normalize_name(local_name(tname).rsplit(".", 1)[-1])
+        table_is_covered = tkey in class_keys
+        for col in t["columns"]:
+            if is_surrogate_or_audit(col):
+                continue
+            total_columns += 1
+            ckey = _column_key(tname, col)
+            if normalize_name(col) in dp_keys:
+                covered_columns.add(ckey)
+                table_is_covered = True
+        if table_is_covered:
+            covered_tables.add(tname)
+
+    return {
+        "total_tables": total_tables,
+        "covered_tables": covered_tables,
+        "total_columns": total_columns,
+        "covered_columns": covered_columns,
+    }
+
+
+# =====================================================
+# Stage-1 metrics + issues
+# =====================================================
+
+
+def evaluate_ontology(
+    ontology: dict,
+    metadata: dict,
+) -> Tuple[Dict[str, Any], List[Dict[str, str]], Dict[str, Any]]:
+    """Run the deterministic Stage-1 checks.
+
+    Returns ``(metrics, issues, footprint)``:
+
+    * ``metrics`` — the §3.2 metric block (ratios + absolute counts).
+    * ``issues`` — actionable failures (``check/expected/observed/hint``)
+      for the owl-gen Evaluator's retry_hints.
+    * ``footprint`` — covered tables/columns sets reused by
+      ``pipeline.coverage_loss``.
+ """ + norm = normalize_ontology(ontology) + tables = normalize_metadata(metadata) + footprint = compute_footprint(norm, tables) + + issues: List[Dict[str, str]] = [] + + # ---- coverage ratios (Tier-2 warn) ----------------------------- + table_cov = ( + len(footprint["covered_tables"]) / footprint["total_tables"] + if footprint["total_tables"] + else 1.0 + ) + column_cov = ( + len(footprint["covered_columns"]) / footprint["total_columns"] + if footprint["total_columns"] + else 1.0 + ) + + uncovered_tables = [ + t["name"] + for t in tables + if t["name"] not in footprint["covered_tables"] + ] + for tname in uncovered_tables: + issues.append( + _issue( + "table_footprint_coverage", + "table maps to a class or contributes >=1 data property", + "no footprint", + f"source table '{tname}' has no class and contributes no data " + "property — model it as a class, attach its columns as data " + "properties on an existing class, or justify the omission.", + ) + ) + + # ---- orphan classes (Tier-1 absolute = 0) ---------------------- + related: Set[str] = set() + for op in norm.object_properties: + for ref in (op.get("domain"), op.get("range")): + if ref: + related.add(local_name(ref)) + related.add(str(ref)) + orphan_classes: List[str] = [] + for c in norm.classes: + has_props = bool(c.get("data_properties")) + name = c.get("name") or local_name(c.get("uri")) + in_rel = name in related or local_name(c.get("uri")) in related + if not has_props and not in_rel: + orphan_classes.append(name) + issues.append( + _issue( + "orphan_class_count", + "0 orphan classes", + name, + f"class '{name}' is an orphan (no data properties and no " + "object-property domain/range) — attach properties, relate " + "it to another class, or remove it.", + ) + ) + + # ---- dangling domain/range (Tier-1 absolute = 0) --------------- + resolvable = norm.class_resolution_set + dangling_dr: List[str] = [] + for op in norm.object_properties: + opname = op.get("name") or local_name(op.get("uri")) + for 
role in ("domain", "range"): + ref = op.get(role) + if not ref: + dangling_dr.append(f"{opname}.{role}") + issues.append( + _issue( + "dangling_domain_range_count", + f"ObjectProperty {role} resolves to a class", + f"{opname}.{role}=", + f"ObjectProperty '{opname}' has no {role} — declare an " + f"rdfs:{role} pointing at an existing class.", + ) + ) + continue + if ref not in resolvable and local_name(ref) not in resolvable: + dangling_dr.append(f"{opname}.{role}") + issues.append( + _issue( + "dangling_domain_range_count", + f"ObjectProperty {role} resolves to a class", + f"{opname}.{role}={local_name(ref)}", + f"ObjectProperty '{opname}' has {role} " + f"'{local_name(ref)}' which resolves to no class — fix " + "the reference or add the missing class.", + ) + ) + + # ---- naming violations (Tier-1 absolute = 0) ------------------- + naming_violations: List[str] = [] + for c in norm.classes: + nm = local_name(c.get("name") or c.get("uri")) + if nm and not _CLASS_RE.match(nm): + naming_violations.append(f"class:{nm}") + issues.append( + _issue( + "naming_violation_count", + "class name is PascalCase [A-Z][A-Za-z0-9]*", + nm, + f"class '{nm}' violates PascalCase — remove spaces / " + "underscores / hyphens and capitalise (e.g. 
sales_order -> " + "SalesOrder).", + ) + ) + for op in norm.object_properties: + nm = local_name(op.get("name") or op.get("uri")) + if nm and not _PROPERTY_RE.match(nm): + naming_violations.append(f"property:{nm}") + issues.append( + _issue( + "naming_violation_count", + "property name is lowerCamelCase [a-z][A-Za-z0-9]*", + nm, + f"property '{nm}' violates lowerCamelCase — use " + "[a-z][A-Za-z0-9]* with no underscores/hyphens/escapes.", + ) + ) + # data properties too + for c in norm.classes: + for dp in c.get("data_properties", []): + nm = local_name(dp) + if nm and not _PROPERTY_RE.match(nm): + naming_violations.append(f"dataproperty:{nm}") + issues.append( + _issue( + "naming_violation_count", + "data property name is lowerCamelCase", + nm, + f"data property '{nm}' violates lowerCamelCase — use " + "[a-z][A-Za-z0-9]* with no underscores/hyphens/escapes.", + ) + ) + + # ---- duplicate classes (Tier-1 absolute = 0) ------------------- + seen: Dict[str, int] = {} + for c in norm.classes: + key = normalize_name(local_name(c.get("name") or c.get("uri"))) + if not key: + continue + seen[key] = seen.get(key, 0) + 1 + duplicate_class_count = sum(n - 1 for n in seen.values() if n > 1) + for key, n in seen.items(): + if n > 1: + issues.append( + _issue( + "duplicate_class_count", + "0 duplicate class local names", + f"{key} x{n}", + f"{n} classes collapse to the local name '{key}' — merge " + "them or differentiate their names/definitions.", + ) + ) + + metrics: Dict[str, Any] = { + "table_footprint_coverage": round(table_cov, 6), + "column_footprint_coverage": round(column_cov, 6), + "orphan_class_count": len(orphan_classes), + "dangling_domain_range_count": len(dangling_dr), + "naming_violation_count": len(naming_violations), + "duplicate_class_count": duplicate_class_count, + } + return metrics, issues, footprint diff --git a/tests/units/pge_eval/__init__.py b/tests/units/pge_eval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/tests/units/pge_eval/_fixtures.py b/tests/units/pge_eval/_fixtures.py new file mode 100644 index 00000000..71ee60d6 --- /dev/null +++ b/tests/units/pge_eval/_fixtures.py @@ -0,0 +1,178 @@ +"""Synthetic, usecase-agnostic fixtures for the PGE evaluator unit tests. + +Deliberately uses a generic e-commerce-ish toy domain (Customer / Order / +Product) so the tests prove the scorer is domain-free — none of these names +appear in the scorer code. +""" + +from copy import deepcopy + + +def clean_ontology() -> dict: + """Agent-shape ontology that is fully structurally clean.""" + return { + "entities": [ + { + "uri": "ex:Customer", + "name": "Customer", + "attributes": ["firstName", "lastName", "email"], + }, + { + "uri": "ex:Order", + "name": "Order", + "attributes": ["orderDate", "totalAmount"], + }, + { + "uri": "ex:Product", + "name": "Product", + "attributes": ["sku", "unitPrice"], + }, + ], + "relationships": [ + { + "uri": "ex:placesOrder", + "name": "placesOrder", + "domain": "ex:Customer", + "range": "ex:Order", + }, + { + "uri": "ex:containsProduct", + "name": "containsProduct", + "domain": "ex:Order", + "range": "ex:Product", + }, + ], + } + + +def clean_metadata() -> dict: + return { + "tables": [ + { + "name": "customers", + "columns": [ + {"name": "id"}, + {"name": "first_name"}, + {"name": "last_name"}, + {"name": "email"}, + {"name": "created_at"}, + ], + }, + { + "name": "orders", + "columns": [ + {"name": "id"}, + {"name": "order_date"}, + {"name": "total_amount"}, + ], + }, + { + "name": "products", + "columns": [ + {"name": "id"}, + {"name": "sku"}, + {"name": "unit_price"}, + ], + }, + ] + } + + +def clean_artifact() -> dict: + onto = clean_ontology() + meta = clean_metadata() + return { + "success": True, + "iterations": 3, + "usage": {"prompt_tokens": 1000, "completion_tokens": 400}, + "stats": {"planner_reinvocations": 0}, + "mapping_run_log": [ + {"item": "ex:Customer", "kind": "entity", "attempts": [{}], "final_status": "PASS"}, + {"item": 
"ex:Order", "kind": "entity", "attempts": [{}], "final_status": "PASS"}, + {"item": "ex:Product", "kind": "entity", "attempts": [{}], "final_status": "PASS"}, + {"item": "ex:placesOrder", "kind": "relationship", "attempts": [{}], "final_status": "PASS"}, + {"item": "ex:containsProduct", "kind": "relationship", "attempts": [{}], "final_status": "PASS"}, + ], + "mapping_evaluations": { + "ex:Customer": {"metrics": {"row_count": 100, "distinct_id_count": 100, "null_id_count": 0}, "failures": []}, + "ex:Order": {"metrics": {"row_count": 500, "distinct_id_count": 500, "null_id_count": 0}, "failures": []}, + "ex:Product": {"metrics": {"row_count": 50, "distinct_id_count": 50, "null_id_count": 0}, "failures": []}, + "ex:placesOrder": {"metrics": {"total_edges": 500, "dangling_source_pct": 0.0, "dangling_target_pct": 0.0}, "failures": []}, + "ex:containsProduct": {"metrics": {"total_edges": 800, "dangling_source_pct": 0.0, "dangling_target_pct": 0.0}, "failures": []}, + }, + "entity_mappings": [ + {"ontology_class": "ex:Customer", "attribute_mappings": {"firstName": "first_name", "lastName": "last_name", "email": "email"}}, + {"ontology_class": "ex:Order", "attribute_mappings": {"orderDate": "order_date", "totalAmount": "total_amount"}}, + {"ontology_class": "ex:Product", "attribute_mappings": {"sku": "sku", "unitPrice": "unit_price"}}, + ], + "relationship_mappings": [], + "steps": [{"step_type": "planner", "tool_name": "", "duration_ms": 1200}], + "ontology": onto, + "metadata": meta, + "elapsed_s": 42.5, + } + + +def artifact_with_dangling_fk() -> dict: + """Clean except one relationship has a dangling target FK > 5%.""" + art = clean_artifact() + art["mapping_evaluations"]["ex:placesOrder"]["metrics"]["dangling_target_pct"] = 0.47 + return art + + +def artifact_with_sql_failure() -> dict: + """Clean except one entity's SQL failed to execute.""" + art = clean_artifact() + art["mapping_evaluations"]["ex:Order"] = { + "metrics": {"sql_error": "UNION type mismatch"}, + 
"failures": [ + { + "check": "sql_execution", + "expected": "SQL executes without error", + "observed": "execution error", + "hint": "fix the SQL", + } + ], + } + # The entity drops out of PASS in the run log too (in-scope but failed). + for entry in art["mapping_run_log"]: + if entry["item"] == "ex:Order": + entry["final_status"] = "FAIL" + return art + + +def ontology_with_orphan() -> dict: + """Add a class with no data properties and no relationships.""" + onto = clean_ontology() + onto["entities"].append({"uri": "ex:Ghost", "name": "Ghost", "attributes": []}) + return onto + + +def artifact_with_orphan_class() -> dict: + art = clean_artifact() + art["ontology"] = ontology_with_orphan() + return art + + +def ontology_with_dangling_range() -> dict: + onto = clean_ontology() + onto["relationships"].append( + {"uri": "ex:refersTo", "name": "refersTo", "domain": "ex:Order", "range": "ex:Nonexistent"} + ) + return onto + + +def ontology_with_naming_violation() -> dict: + onto = clean_ontology() + onto["entities"].append( + {"uri": "ex:bad_class", "name": "bad_class", "attributes": ["someAttr"]} + ) + return onto + + +def ontology_with_duplicate_class() -> dict: + onto = clean_ontology() + onto["entities"].append( + {"uri": "ex:Customer2", "name": "Customer", "attributes": ["nickname"]} + ) + return onto diff --git a/tests/units/pge_eval/test_ontology_metrics.py b/tests/units/pge_eval/test_ontology_metrics.py new file mode 100644 index 00000000..640a61ec --- /dev/null +++ b/tests/units/pge_eval/test_ontology_metrics.py @@ -0,0 +1,83 @@ +"""Stage-1 ontology metric tests (deterministic, no LLM).""" + +import pytest + +from agents.pge_eval.ontology_metrics import evaluate_ontology +from agents.pge_eval.normalize import is_surrogate_or_audit, normalize_name + +from tests.units.pge_eval import _fixtures as fx + + +def test_clean_ontology_all_absolute_zero(): + metrics, issues, _ = evaluate_ontology(fx.clean_ontology(), fx.clean_metadata()) + assert 
metrics["orphan_class_count"] == 0 + assert metrics["dangling_domain_range_count"] == 0 + assert metrics["naming_violation_count"] == 0 + assert metrics["duplicate_class_count"] == 0 + assert metrics["table_footprint_coverage"] == 1.0 + assert metrics["column_footprint_coverage"] >= 0.9 + + +def test_orphan_class_detected(): + metrics, issues, _ = evaluate_ontology(fx.ontology_with_orphan(), fx.clean_metadata()) + assert metrics["orphan_class_count"] == 1 + assert any(i["check"] == "orphan_class_count" for i in issues) + + +def test_dangling_range_detected(): + metrics, issues, _ = evaluate_ontology( + fx.ontology_with_dangling_range(), fx.clean_metadata() + ) + assert metrics["dangling_domain_range_count"] == 1 + assert any(i["check"] == "dangling_domain_range_count" for i in issues) + + +def test_naming_violation_detected(): + metrics, _, _ = evaluate_ontology( + fx.ontology_with_naming_violation(), fx.clean_metadata() + ) + assert metrics["naming_violation_count"] >= 1 + + +def test_duplicate_class_detected(): + metrics, _, _ = evaluate_ontology( + fx.ontology_with_duplicate_class(), fx.clean_metadata() + ) + assert metrics["duplicate_class_count"] == 1 + + +def test_table_coverage_drops_with_unmodelled_table(): + meta = fx.clean_metadata() + meta["tables"].append({"name": "shipments", "columns": [{"name": "carrier"}]}) + metrics, issues, _ = evaluate_ontology(fx.clean_ontology(), meta) + assert metrics["table_footprint_coverage"] < 1.0 + assert any( + i["check"] == "table_footprint_coverage" for i in issues + ) + + +def test_surrogate_and_audit_columns_excluded(): + assert is_surrogate_or_audit("id") + assert is_surrogate_or_audit("created_at") + assert is_surrogate_or_audit("customer_sk") + assert is_surrogate_or_audit("etl_load_ts") + assert not is_surrogate_or_audit("first_name") + assert not is_surrogate_or_audit("customer_id") # FK can be meaningful + + +def test_name_normalization(): + assert normalize_name("first_name") == normalize_name("firstName") == 
"firstname" + assert normalize_name("Order Date") == "orderdate" + + +def test_registry_shape_accepted(): + # Same ontology in registry (classes/properties) shape must score identically. + registry = { + "classes": [ + {"uri": "ex:A", "name": "A", "dataProperties": [{"name": "x"}]}, + ], + "properties": [], + } + metrics, _, _ = evaluate_ontology(registry, {"tables": []}) + # A has a data property -> not an orphan. + assert metrics["orphan_class_count"] == 0 diff --git a/tests/units/pge_eval/test_owl_evaluator_stage.py b/tests/units/pge_eval/test_owl_evaluator_stage.py new file mode 100644 index 00000000..ec23bf47 --- /dev/null +++ b/tests/units/pge_eval/test_owl_evaluator_stage.py @@ -0,0 +1,69 @@ +"""The owl-generator Evaluator stage (§3.5) — deterministic Stage-1 checks +feeding retry_hints, with a bounded retry cap.""" + +from agents.agent_owl_generator import engine as owl_engine + +_CLEAN_TTL = """@prefix owl: . +@prefix rdf: . +@prefix rdfs: . +@prefix xsd: . +@prefix : . + + a owl:Ontology . + +:Customer a owl:Class ; rdfs:label "Customer" . +:Order a owl:Class ; rdfs:label "Order" . + +:placesOrder a owl:ObjectProperty ; rdfs:domain :Customer ; rdfs:range :Order . +:firstName a owl:DatatypeProperty ; rdfs:domain :Customer ; rdfs:range xsd:string . +:orderDate a owl:DatatypeProperty ; rdfs:domain :Order ; rdfs:range xsd:string . +""" + +_ORPHAN_TTL = _CLEAN_TTL + """ +:Ghost a owl:Class ; rdfs:label "Ghost" . +""" + + +def test_clean_ontology_returns_no_retry_hint(): + assert owl_engine._evaluate_ontology_stage(_CLEAN_TTL, {}, 1) is None + + +def test_orphan_class_yields_retry_hint(): + hint = owl_engine._evaluate_ontology_stage(_ORPHAN_TTL, {}, 1) + assert hint is not None + assert "Ghost" in hint + assert "orphan" in hint.lower() + + +_PROSE_PREFIXED = ( + "No database tables are available. 
I have what I need from the guidelines.\n\n" + + _ORPHAN_TTL +) + +_FENCED = "```turtle\n" + _ORPHAN_TTL + "```" + + +def test_prose_preamble_is_stripped_before_parsing(): + # Regression: the model sometimes prepends a sentence before @prefix. The + # evaluator must clean it (like the downstream registry) and still run, + # not skip. Found via a live Chrome DevTools generation run. + hint = owl_engine._evaluate_ontology_stage(_PROSE_PREFIXED, {}, 1) + assert hint is not None + assert "Ghost" in hint + + +def test_markdown_fenced_turtle_is_parsed(): + hint = owl_engine._evaluate_ontology_stage(_FENCED, {}, 1) + assert hint is not None + assert "Ghost" in hint + + +def test_parse_error_fails_open(): + # Garbage in -> None (never blocks OWL delivery). + assert owl_engine._evaluate_ontology_stage("not turtle at all {{{", {}, 1) is None + + +def test_evaluator_loop_is_bounded(): + # The Evaluator retry cap exists and is finite (real PGE discipline). + assert owl_engine.MAX_OWL_EVAL_ROUNDS >= 1 + assert owl_engine.MAX_OWL_EVAL_ROUNDS < 10 From 5f33e1198192a60eec6ba3ba2a72b6e2e88869cb Mon Sep 17 00:00:00 2001 From: Fiifi Botchway Date: Thu, 25 Jun 2026 12:31:49 +0100 Subject: [PATCH 2/4] feat(mapping): add PGE loop for entity/relationship mapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce agent_mapping_pge — a Planner→Generator→Evaluator mapping engine — additively. Plans a source-model, generates entity + relationship SQL per ontology item, and gates each with a deterministic evaluator + a semantic critic. Coverage is engine-enforced from the ontology (abstract-superclass UNION derivation + synthetic-endpoint fallback so one failed hub can't drop all relationships). Additive: agent_auto_assignment is retained and still reachable via AgentClient.run_auto_assignment; a new AgentClient.run_mapping_pge gateway exposes the PGE engine, so an orchestrator can choose between them. 
Upstream features preserved (ToolContext.warehouse_id, Mapping._canonicalize_imported_uris). - NEW agent_mapping_pge package + tools/{planner,evaluation}.py - context.py: +source_model/+semantic_eval_report fields (warehouse_id kept) - Mapping.py: run PGE + accumulate source_model/evaluations/run_log; save_mappings_to_session gains 3 optional params (legacy path unaffected) - Tests: 90 in tests/agents/agent_mapping_pge; 208 across units/{agents,mapping} Co-authored-by: Isaac --- changelogs/v0.5.2/FiifiB_2026-06-25.log | 57 + src/agents/agent_mapping_pge/__init__.py | 50 + src/agents/agent_mapping_pge/contracts.py | 344 ++++ src/agents/agent_mapping_pge/coverage.py | 353 ++++ src/agents/agent_mapping_pge/engine.py | 1491 +++++++++++++++++ .../agent_mapping_pge/evaluator/__init__.py | 18 + .../agent_mapping_pge/evaluator/critic.py | 747 +++++++++ .../evaluator/deterministic.py | 539 ++++++ .../agent_mapping_pge/evaluator/report.py | 37 + .../agent_mapping_pge/generators/__init__.py | 31 + .../agent_mapping_pge/generators/entity.py | 812 +++++++++ .../generators/relationship.py | 875 ++++++++++ src/agents/agent_mapping_pge/planner.py | 729 ++++++++ src/agents/tools/context.py | 15 +- src/agents/tools/evaluation.py | 205 +++ src/agents/tools/mapping.py | 78 +- src/agents/tools/planner.py | 707 ++++++++ src/back/core/agents/AgentClient.py | 62 + src/back/objects/mapping/Mapping.py | 62 +- tests/agents/agent_mapping_pge/__init__.py | 0 .../agent_mapping_pge/test_contracts.py | 126 ++ .../agents/agent_mapping_pge/test_coverage.py | 140 ++ tests/agents/agent_mapping_pge/test_critic.py | 684 ++++++++ .../test_deterministic_evaluator.py | 916 ++++++++++ tests/agents/agent_mapping_pge/test_engine.py | 1117 ++++++++++++ .../test_entity_generator.py | 662 ++++++++ .../agents/agent_mapping_pge/test_planner.py | 519 ++++++ .../test_relationship_generator.py | 736 ++++++++ 28 files changed, 12104 insertions(+), 8 deletions(-) create mode 100644 
changelogs/v0.5.2/FiifiB_2026-06-25.log create mode 100644 src/agents/agent_mapping_pge/__init__.py create mode 100644 src/agents/agent_mapping_pge/contracts.py create mode 100644 src/agents/agent_mapping_pge/coverage.py create mode 100644 src/agents/agent_mapping_pge/engine.py create mode 100644 src/agents/agent_mapping_pge/evaluator/__init__.py create mode 100644 src/agents/agent_mapping_pge/evaluator/critic.py create mode 100644 src/agents/agent_mapping_pge/evaluator/deterministic.py create mode 100644 src/agents/agent_mapping_pge/evaluator/report.py create mode 100644 src/agents/agent_mapping_pge/generators/__init__.py create mode 100644 src/agents/agent_mapping_pge/generators/entity.py create mode 100644 src/agents/agent_mapping_pge/generators/relationship.py create mode 100644 src/agents/agent_mapping_pge/planner.py create mode 100644 src/agents/tools/evaluation.py create mode 100644 src/agents/tools/planner.py create mode 100644 tests/agents/agent_mapping_pge/__init__.py create mode 100644 tests/agents/agent_mapping_pge/test_contracts.py create mode 100644 tests/agents/agent_mapping_pge/test_coverage.py create mode 100644 tests/agents/agent_mapping_pge/test_critic.py create mode 100644 tests/agents/agent_mapping_pge/test_deterministic_evaluator.py create mode 100644 tests/agents/agent_mapping_pge/test_engine.py create mode 100644 tests/agents/agent_mapping_pge/test_entity_generator.py create mode 100644 tests/agents/agent_mapping_pge/test_planner.py create mode 100644 tests/agents/agent_mapping_pge/test_relationship_generator.py diff --git a/changelogs/v0.5.2/FiifiB_2026-06-25.log b/changelogs/v0.5.2/FiifiB_2026-06-25.log new file mode 100644 index 00000000..aa803cfa --- /dev/null +++ b/changelogs/v0.5.2/FiifiB_2026-06-25.log @@ -0,0 +1,57 @@ +# 2026-06-25 — feat(mapping): PGE loop for entity/relationship mapping + +## Context + +Entity/relationship mapping previously ran through `agent_auto_assignment` — +a single-agent "implementer marks its own homework" 
loop with no planning or +independent evaluation. This change introduces `agent_mapping_pge`, a +Planner→Generator→Evaluator (PGE) mapping engine, **additively**: the original +`agent_auto_assignment` engine is retained and still reachable via +`AgentClient.run_auto_assignment`, so a downstream orchestrator can choose which +engine to run. + +The PGE engine plans a source-model, generates entity and relationship SQL per +ontology item, and gates each with a deterministic evaluator + a semantic +critic. Coverage is engine-enforced (computed from the ontology, not left to LLM +discretion), with abstract-superclass UNION derivation and a synthetic-endpoint +fallback so a single failed hub never cascades to drop all relationships. + +## Changes + +1. NEW package `src/agents/agent_mapping_pge/` — Planner (`planner.py`), + generators (`generators/{entity,relationship}.py`), evaluator + (`evaluator/{deterministic,critic,report}.py`), engine orchestrator + (`engine.py`, bounded ThreadPool walk + monotonic progress), `contracts.py` + (SourceModel/EvalReport), and `coverage.py` (deterministic ontology-derived + coverage; `skip[]` is advisory and never removes an item). +2. NEW `src/agents/tools/planner.py` + `src/agents/tools/evaluation.py` — + planner/evaluation terminal tools (submit_source_model, submit_evaluation, + normalized_value_overlap) used by the PGE agents. +3. `src/agents/tools/context.py` — ADD `source_model` + `semantic_eval_report` + fields (forward-ref typed to avoid a circular import). `warehouse_id` and all + existing fields are preserved. +4. `src/agents/tools/mapping.py` — additive PGE tool-schema plumbing + (`unmapped_attributes`, `MAPPING_TOOL_DEFINITIONS_BY_NAME`). +5. `src/back/core/agents/AgentClient.py` — ADD `run_mapping_pge()` gateway + (→ `agent_mapping_pge`). `run_auto_assignment()` is unchanged and still + points at `agent_auto_assignment` (the simple engine is retained). +6. 
`src/back/objects/mapping/Mapping.py` — run the PGE engine in the auto-assign + flow and accumulate the PGE extras (`source_model`, `mapping_evaluations`, + `mapping_run_log`) across chunks and single-item runs; + `save_mappings_to_session` gains three OPTIONAL params (default `None`, so the + legacy path is unaffected). The upstream `_canonicalize_imported_uris` helper + is preserved. +7. Tests: `tests/agents/agent_mapping_pge/` — contracts, coverage, planner, + entity/relationship generators, deterministic evaluator, critic, engine. + +## Modified / added files + +27 files changed, 12047 insertions(+), 8 deletions(-). New `agent_mapping_pge` +package (12 modules) + 2 new tools + 9 test modules; 4 additive modifications +(`context.py`, `mapping.py`, `AgentClient.py`, `Mapping.py`). + +## Tests + +- `uv run pytest tests/agents/agent_mapping_pge -q` → **90 passed**. +- `uv run pytest tests/units/agents tests/units/mapping -q` → **208 passed**. +- Imports resolve on the upstream base (origin/master, v0.5.2). diff --git a/src/agents/agent_mapping_pge/__init__.py b/src/agents/agent_mapping_pge/__init__.py new file mode 100644 index 00000000..0da0713e --- /dev/null +++ b/src/agents/agent_mapping_pge/__init__.py @@ -0,0 +1,50 @@ +"""Planner -> Generator -> Evaluator (PGE) mapping agent. + +Three-stage mapping pipeline that replaces the prior single-loop ReAct +mapping agent: + +* **Planner** — proposes a :class:`SourceModel` (table roles, canonical IDs, + join keys, ordered mapping plan). +* **Generator** — produces individual entity/relationship mappings given the + plan. +* **Evaluator** — checks each submitted mapping; stage 1 is deterministic + (pure SQL counts), stage 2 is semantic. + +Sprint 1 lays the foundation: the typed contracts plus the deterministic +evaluator. Subsequent sprints add the LLM-backed Planner, Generator, +semantic Evaluator, and the orchestrating loop. 
+""" + +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + EvalFailure, + EvalReport, + JoinKey, + MappingPlan, + RetryState, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) +from agents.agent_mapping_pge.engine import ( + AgentResult, + AgentStep, + run_agent, +) + +__all__ = [ + "AgentResult", + "AgentStep", + "CanonicalId", + "EvalFailure", + "EvalReport", + "JoinKey", + "MappingPlan", + "RetryState", + "SkipItem", + "SourceModel", + "TableRole", + "TableRoleCandidate", + "run_agent", +] diff --git a/src/agents/agent_mapping_pge/contracts.py b/src/agents/agent_mapping_pge/contracts.py new file mode 100644 index 00000000..8172e431 --- /dev/null +++ b/src/agents/agent_mapping_pge/contracts.py @@ -0,0 +1,344 @@ +"""Typed contracts for the mapping PGE pipeline. + +These dataclasses are the load-bearing interface between Planner, Generator, +Evaluator, and the orchestrator (added in later sprints). All shapes here +are JSON round-trippable via ``to_dict()`` / ``from_dict()`` so they can be +persisted as artefacts, attached to MLflow traces, or shipped over the wire +to the UI. + +No LLM code lives here; this is a pure-data module. +""" + +from dataclasses import dataclass, field, fields, is_dataclass +from typing import Any, Dict, List, Optional + + +# ===================================================== +# SourceModel — Planner output +# ===================================================== + + +@dataclass +class TableRoleCandidate: + """A candidate ontology class for a given source table.""" + + uri: str + confidence: float # 0.0 .. 
1.0 + reason: str = "" + + def to_dict(self) -> Dict[str, Any]: + return {"uri": self.uri, "confidence": self.confidence, "reason": self.reason} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TableRoleCandidate": + return cls( + uri=data["uri"], + confidence=float(data["confidence"]), + reason=data.get("reason", ""), + ) + + +@dataclass +class TableRole: + """A source table together with its ranked ontology-class candidates.""" + + table: str # full name catalog.schema.table + ontology_class_candidates: List[TableRoleCandidate] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "table": self.table, + "ontology_class_candidates": [ + c.to_dict() for c in self.ontology_class_candidates + ], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TableRole": + return cls( + table=data["table"], + ontology_class_candidates=[ + TableRoleCandidate.from_dict(c) + for c in data.get("ontology_class_candidates", []) + ], + ) + + +@dataclass +class CanonicalId: + """Identifier conventions for an ontology class across its source tables. + + ``canonical_column_per_table`` maps a full table name -> the column to + use as the canonical identifier in that table (e.g. NHS number rather + than the trust-local patient id). + """ + + ontology_class: str # class URI + canonical_column_per_table: Dict[str, str] = field(default_factory=dict) + format_note: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "ontology_class": self.ontology_class, + "canonical_column_per_table": dict(self.canonical_column_per_table), + "format_note": self.format_note, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CanonicalId": + return cls( + ontology_class=data["ontology_class"], + canonical_column_per_table=dict( + data.get("canonical_column_per_table", {}) + ), + format_note=data.get("format_note", ""), + ) + + +@dataclass +class JoinKey: + """A proposed join between two table.column references. 
+ + ``kind`` distinguishes within-trust foreign keys from value-matched + cross-source joins (e.g. NHS-number-to-NHS-number across trusts). + """ + + from_ref: str # "table.col" + to_ref: str # "table.col" + confidence: float # 0..1 + overlap_pct: float # 0..1 + kind: str # "same_trust_fk" | "cross_source_value_match" + + def to_dict(self) -> Dict[str, Any]: + return { + "from_ref": self.from_ref, + "to_ref": self.to_ref, + "confidence": self.confidence, + "overlap_pct": self.overlap_pct, + "kind": self.kind, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JoinKey": + return cls( + from_ref=data["from_ref"], + to_ref=data["to_ref"], + confidence=float(data["confidence"]), + overlap_pct=float(data["overlap_pct"]), + kind=data["kind"], + ) + + +@dataclass +class SkipItem: + """An ontology entity/relationship the planner has decided to skip.""" + + item: str # uri + reason: str = "" + + def to_dict(self) -> Dict[str, Any]: + return {"item": self.item, "reason": self.reason} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SkipItem": + return cls(item=data["item"], reason=data.get("reason", "")) + + +@dataclass +class MappingPlan: + """The order in which the Generator should attempt entity/relationship + mappings, plus any items the planner chose to drop.""" + + entity_order: List[str] = field(default_factory=list) + relationship_order: List[str] = field(default_factory=list) + skip: List[SkipItem] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "entity_order": list(self.entity_order), + "relationship_order": list(self.relationship_order), + "skip": [s.to_dict() for s in self.skip], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MappingPlan": + return cls( + entity_order=list(data.get("entity_order", [])), + relationship_order=list(data.get("relationship_order", [])), + skip=[SkipItem.from_dict(s) for s in data.get("skip", [])], + ) + + +@dataclass +class SourceModel: + 
"""Output of the Planner stage; input to the Generator. + + Contains the planner's understanding of the source schema (table roles, + canonical ids, join keys) and the ordered plan of work for the + Generator. + """ + + table_roles: List[TableRole] = field(default_factory=list) + canonical_ids: List[CanonicalId] = field(default_factory=list) + join_keys: List[JoinKey] = field(default_factory=list) + mapping_plan: MappingPlan = field(default_factory=MappingPlan) + + def to_dict(self) -> Dict[str, Any]: + return { + "table_roles": [t.to_dict() for t in self.table_roles], + "canonical_ids": [c.to_dict() for c in self.canonical_ids], + "join_keys": [j.to_dict() for j in self.join_keys], + "mapping_plan": self.mapping_plan.to_dict(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SourceModel": + return cls( + table_roles=[ + TableRole.from_dict(t) for t in data.get("table_roles", []) + ], + canonical_ids=[ + CanonicalId.from_dict(c) for c in data.get("canonical_ids", []) + ], + join_keys=[JoinKey.from_dict(j) for j in data.get("join_keys", [])], + mapping_plan=MappingPlan.from_dict(data.get("mapping_plan", {})), + ) + + +# ===================================================== +# EvalReport — Evaluator output +# ===================================================== + + +@dataclass +class EvalFailure: + """A single failed check inside an :class:`EvalReport`. + + ``hint`` is the actionable correction text fed back to the Generator on + retry; it should be concrete and template-y, not a free-form essay. + """ + + kind: str # "structural" | "semantic" + check: str # e.g. "dangling_source_pct" + expected: str # e.g. "< 0.05" + observed: str # e.g. 
"0.47" + hint: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "kind": self.kind, + "check": self.check, + "expected": self.expected, + "observed": self.observed, + "hint": self.hint, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EvalFailure": + return cls( + kind=data["kind"], + check=data["check"], + expected=data["expected"], + observed=data["observed"], + hint=data.get("hint", ""), + ) + + +@dataclass +class EvalReport: + """Outcome of evaluating a single submitted mapping. + + ``bubble_to_planner`` signals that the failure cannot reasonably be + fixed by the Generator alone and warrants re-planning (e.g. wrong + canonical id column, table assigned to wrong ontology class). + """ + + status: str # "PASS" | "FAIL" + stage: str # "deterministic" | "semantic" + metrics: Dict[str, Any] = field(default_factory=dict) + failures: List[EvalFailure] = field(default_factory=list) + bubble_to_planner: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "status": self.status, + "stage": self.stage, + "metrics": dict(self.metrics), + "failures": [f.to_dict() for f in self.failures], + "bubble_to_planner": self.bubble_to_planner, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EvalReport": + return cls( + status=data["status"], + stage=data["stage"], + metrics=dict(data.get("metrics", {})), + failures=[EvalFailure.from_dict(f) for f in data.get("failures", [])], + bubble_to_planner=bool(data.get("bubble_to_planner", False)), + ) + + +# ===================================================== +# RetryState — orchestrator bookkeeping (used in Sprint 7) +# ===================================================== + + +@dataclass +class RetryState: + """Per-item retry budget tracked by the orchestrator. + + The orchestrator caps the Generator at 3 attempts per item before + giving up, and bumps the Planner at most twice per item if the + evaluator keeps bubbling failures upstream. 
+ """ + + item_uri: str + generator_attempts: int = 0 + planner_reinvocations: int = 0 + last_eval_report: Optional[EvalReport] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "item_uri": self.item_uri, + "generator_attempts": self.generator_attempts, + "planner_reinvocations": self.planner_reinvocations, + "last_eval_report": ( + self.last_eval_report.to_dict() + if self.last_eval_report is not None + else None + ), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RetryState": + last = data.get("last_eval_report") + return cls( + item_uri=data["item_uri"], + generator_attempts=int(data.get("generator_attempts", 0)), + planner_reinvocations=int(data.get("planner_reinvocations", 0)), + last_eval_report=EvalReport.from_dict(last) if last is not None else None, + ) + + +# ===================================================== +# Sanity check — keep dataclass discovery introspectable +# ===================================================== + +_ALL_CONTRACTS = ( + TableRoleCandidate, + TableRole, + CanonicalId, + JoinKey, + SkipItem, + MappingPlan, + SourceModel, + EvalFailure, + EvalReport, + RetryState, +) +for _cls in _ALL_CONTRACTS: + assert is_dataclass(_cls), f"{_cls.__name__} must be a dataclass" + # touch ``fields`` to ensure all defaults are well-formed at import time. + fields(_cls) +del _cls diff --git a/src/agents/agent_mapping_pge/coverage.py b/src/agents/agent_mapping_pge/coverage.py new file mode 100644 index 00000000..9277a3f2 --- /dev/null +++ b/src/agents/agent_mapping_pge/coverage.py @@ -0,0 +1,353 @@ +"""Deterministic coverage enforcement + derived-mapping construction. + +The Planner is good at the *how* (which column is the canonical id, which +tables join, in what order to attempt things) but it must NOT be trusted with +the *what*: the set of entities and relationships that need mapping is fixed — +it is exactly the input ontology. 
Leaving coverage to the LLM produced +non-deterministic partial runs (some classes silently dropped into +``mapping_plan.skip``) and, because relationships were skipped whenever an +endpoint entity was absent, zero relationship coverage. + +This module computes the FULL, dependency-ordered coverage set from the +ontology itself (using the Planner's order only as a hint), and builds the two +kinds of mapping the LLM Generators cannot/should not produce: + +* **Abstract-superclass mappings** — a class with subclasses but no source + table of its own (``Person``, ``Patient``, ``Clinicalencounter``, + ``Clinicalfinding``). Its instances are exactly the UNION of its concrete + leaf subclasses, so we derive it mechanically from the already-validated + subclass mappings. Reusing the subclasses' verbatim ``sql_query`` keeps the + abstract id-universe byte-identical to the union of its parts, which is what + makes relationships pointing at the abstract (e.g. ``managedby`` with domain + ``Clinicalencounter``) join with zero dangling. + +* **Synthetic endpoint id-universe** — when a relationship endpoint entity has + no full mapping yet, we synthesise a minimal ``{sql_query, id_column}`` from + the Planner's ``canonical_ids`` so the relationship can still be attempted + instead of silently skipped. + +All functions here are pure data transforms — no LLM, no I/O — so they are +fast and unit-testable. 
+""" + +from typing import Dict, List, Optional, Set, Tuple + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import SourceModel + +logger = get_logger(__name__) + + +# ===================================================== +# Ontology structure helpers +# ===================================================== + + +def _entities(ontology: dict) -> List[dict]: + return (ontology or {}).get("entities", []) or [] + + +def _relationships(ontology: dict) -> List[dict]: + return (ontology or {}).get("relationships", []) or [] + + +def _uri(entity: dict) -> str: + return entity.get("uri") or entity.get("name") or "" + + +def name_to_uri(ontology: dict) -> Dict[str, str]: + """Map both short name and label to URI for parent/domain/range resolution.""" + out: Dict[str, str] = {} + for e in _entities(ontology): + uri = _uri(e) + if not uri: + continue + out[uri] = uri + if e.get("name"): + out[e["name"]] = uri + if e.get("label"): + out[e["label"]] = uri + return out + + +def parent_uri(entity: dict, n2u: Dict[str, str]) -> Optional[str]: + """Resolve a class's parent (stored as a name/label/uri) to a URI.""" + p = (entity.get("parent") or "").strip() + if not p: + return None + return n2u.get(p, p if p.startswith("http") else None) + + +def _tables_for_class(source_model: Optional[SourceModel]) -> Set[str]: + """URIs that the Planner assigned at least one source table to.""" + if source_model is None: + return set() + out: Set[str] = set() + for role in source_model.table_roles: + for cand in role.ontology_class_candidates: + out.add(cand.uri) + return out + + +def classify( + ontology: dict, + source_model: Optional[SourceModel], + *, + synthesized_uris: Optional[Set[str]] = None, +) -> Tuple[Set[str], Set[str]]: + """Partition classes into (concrete, abstract). + + A class is **abstract/derived** when it has subclasses but no source table + of its own — its rows are the union of its concrete descendants. 
Every + other class is **concrete** (it has, or will have via synthesis, a source + table and gets a normal Generator mapping). + """ + n2u = name_to_uri(ontology) + has_children: Set[str] = set() + for e in _entities(ontology): + p = parent_uri(e, n2u) + if p: + has_children.add(p) + + has_table = _tables_for_class(source_model) | (synthesized_uris or set()) + + concrete: Set[str] = set() + abstract: Set[str] = set() + for e in _entities(ontology): + uri = _uri(e) + if not uri: + continue + if uri in has_children and uri not in has_table: + abstract.add(uri) + else: + concrete.add(uri) + return concrete, abstract + + +def concrete_leaf_descendants( + abstract_uri: str, ontology: dict, concrete: Set[str] +) -> List[str]: + """All concrete descendant class URIs beneath ``abstract_uri`` (transitive).""" + n2u = name_to_uri(ontology) + children_of: Dict[str, List[str]] = {} + for e in _entities(ontology): + p = parent_uri(e, n2u) + if p: + children_of.setdefault(p, []).append(_uri(e)) + + out: List[str] = [] + stack = list(children_of.get(abstract_uri, [])) + seen: Set[str] = set() + while stack: + cur = stack.pop() + if cur in seen: + continue + seen.add(cur) + if cur in concrete: + out.append(cur) + stack.extend(children_of.get(cur, [])) + return out + + +# ===================================================== +# Coverage ordering (engine-enforced — the "what") +# ===================================================== + + +def full_entity_order( + ontology: dict, + source_model: Optional[SourceModel], + *, + synthesized_uris: Optional[Set[str]] = None, +) -> List[str]: + """Complete entity order: every class, concrete-first, abstracts after + their descendants. + + Uses the Planner's ``entity_order`` only to order concrete classes (it + knows base-vs-referencer dependencies); abstracts are appended in + descendant-before-ancestor order so a derived union can read its parts. 
+ """ + concrete, abstract = classify( + ontology, source_model, synthesized_uris=synthesized_uris + ) + planned = ( + list(source_model.mapping_plan.entity_order) if source_model else [] + ) + + ordered: List[str] = [] + seen: Set[str] = set() + + # 1. Concrete classes in the Planner's order first… + for uri in planned: + if uri in concrete and uri not in seen: + ordered.append(uri) + seen.add(uri) + # 2. …then any concrete class the Planner omitted (coverage guarantee). + for e in _entities(ontology): + uri = _uri(e) + if uri in concrete and uri not in seen: + ordered.append(uri) + seen.add(uri) + + # 3. Abstracts, ordered so each appears after all its descendants. + remaining = [u for u in abstract if u not in seen] + + def _depth(uri: str) -> int: + # number of concrete leaves — deeper subtrees (more leaves) last is + # fine; what matters is a class never precedes its own descendant. + return len(concrete_leaf_descendants(uri, ontology, concrete)) + + # Topologically: a class with fewer abstract-ancestors first. Simple stable + # approach: repeatedly emit abstracts whose abstract-children are all done. + n2u = name_to_uri(ontology) + abstract_children: Dict[str, List[str]] = {} + for e in _entities(ontology): + uri = _uri(e) + if uri in abstract: + p = parent_uri(e, n2u) + if p in abstract: + abstract_children.setdefault(p, []).append(uri) + + progress = True + while remaining and progress: + progress = False + for uri in list(remaining): + kids = abstract_children.get(uri, []) + if all(k in seen for k in kids): + ordered.append(uri) + seen.add(uri) + remaining.remove(uri) + progress = True + # Anything left (cycles — shouldn't happen) just append. 
+ for uri in remaining: + ordered.append(uri) + seen.add(uri) + + return ordered + + +def full_relationship_order( + ontology: dict, entity_order: List[str], source_model: Optional[SourceModel] +) -> List[str]: + """Every object property, ordered so both endpoints precede it where + possible (falls back to Planner order / declaration order).""" + rels = _relationships(ontology) + rel_uris = [r.get("uri") or r.get("name") for r in rels] + rel_uris = [u for u in rel_uris if u] + + planned = ( + list(source_model.mapping_plan.relationship_order) if source_model else [] + ) + ordered: List[str] = [] + seen: Set[str] = set() + for uri in planned + rel_uris: + if uri and uri not in seen: + ordered.append(uri) + seen.add(uri) + return ordered + + +# ===================================================== +# Derived mappings (the "how" the LLM cannot produce) +# ===================================================== + +_ID = "ID" + + +def _attr_names(entity: dict) -> List[str]: + out: List[str] = [] + for a in entity.get("attributes", []) or []: + if isinstance(a, dict): + n = a.get("name") + if n: + out.append(str(n)) + elif a is not None: + out.append(str(a)) + return out + + +def build_abstract_union_mapping( + abstract_uri: str, + abstract_entity: dict, + subclass_mappings: List[dict], +) -> Optional[dict]: + """Build a derived entity mapping for an abstract class as the UNION ALL of + its concrete subclass mappings. + + Reuses each subclass's verbatim ``sql_query`` (wrapped in a subquery) so the + abstract id-universe equals the union of the parts exactly. Projects the + abstract class's own attributes by re-aliasing each subclass's + ``attribute_mappings`` value to the ontology attribute name; subclasses that + do not carry an attribute contribute ``NULL`` for it. 
+ """ + subs = [m for m in subclass_mappings if m and m.get("sql_query")] + if not subs: + return None + + attrs = _attr_names(abstract_entity) + selects: List[str] = [] + for m in subs: + amap = m.get("attribute_mappings") or {} + cols = [_ID] + for attr in attrs: + src_alias = amap.get(attr) + if src_alias: + cols.append(f"{src_alias} AS {attr}") + else: + cols.append(f"CAST(NULL AS STRING) AS {attr}") + selects.append(f"SELECT {', '.join(cols)} FROM ({m['sql_query']}) ") + union = " UNION ALL ".join(selects) + # DISTINCT on the whole projection collapses any incidental duplicates while + # preserving distinct ids (subclass id spaces are disjoint by construction). + sql = f"SELECT DISTINCT * FROM ({union}) _abstract WHERE {_ID} IS NOT NULL" + + return { + "class_uri": abstract_uri, + "ontology_class": abstract_uri, + "class_name": abstract_entity.get("name", abstract_uri), + "sql_query": sql, + "id_column": _ID, + "label_column": _ID, + "attribute_mappings": {a: a for a in attrs}, + "unmapped_attributes": [], + "derived": "abstract_union", + } + + +def synthetic_endpoint_mapping( + source_model: Optional[SourceModel], class_uri: str +) -> Optional[dict]: + """Build a minimal id-universe-only entity mapping for a relationship + endpoint from the Planner's ``canonical_ids``. + + Used as a fallback so a relationship is never skipped just because its + endpoint entity lacks a full mapping. Produces a UNION ALL of + ``SELECT {expr} AS ID FROM {table}`` over every table the + Planner recorded for the class. 
+ """ + if source_model is None: + return None + cid = next( + (c for c in source_model.canonical_ids if c.ontology_class == class_uri), + None, + ) + if cid is None or not cid.canonical_column_per_table: + return None + selects = [ + f"SELECT {expr} AS {_ID} FROM {table}" + for table, expr in cid.canonical_column_per_table.items() + if expr and table + ] + if not selects: + return None + inner = " UNION ALL ".join(selects) + sql = f"SELECT DISTINCT {_ID} FROM ({inner}) _u WHERE {_ID} IS NOT NULL" + return { + "class_uri": class_uri, + "ontology_class": class_uri, + "sql_query": sql, + "id_column": _ID, + "attribute_mappings": {}, + "unmapped_attributes": [], + "derived": "synthetic_endpoint", + } diff --git a/src/agents/agent_mapping_pge/engine.py b/src/agents/agent_mapping_pge/engine.py new file mode 100644 index 00000000..d6245b2a --- /dev/null +++ b/src/agents/agent_mapping_pge/engine.py @@ -0,0 +1,1491 @@ +""" +OntoBricks Mapping-PGE Orchestrator. + +Wires the Planner, the Entity/Relationship Generators, and the two-stage +Evaluator (deterministic + semantic critic) into a single ``run_agent`` +entry point. + +The public ``run_agent`` signature and :class:`AgentResult` shape match the +prior in-house single-loop mapping agent so ``back/objects/mapping/Mapping.py`` +can call this engine without other changes. + +Control flow per item (entity or relationship) +============================================== + +1. Build a focused slice from the Planner's :class:`SourceModel`. +2. Run the appropriate Generator with ``retry_hint=None``. +3. Run the deterministic evaluator. On FAIL: + * if ``bubble_to_planner=True`` -> escalate to Planner (capped at ``_PLANNER_REINVOCATION_BUDGET`` global + replans across the whole run); + * else retry the Generator with the first failure's hint. +4. On stage-1 PASS, run the semantic critic (unless ``skip_semantic_critic`` + is set). Same bubble / hint logic on FAIL. +5. 
After ``_PER_ITEM_GENERATOR_ATTEMPTS`` unsuccessful attempts, the item is recorded as ``FAIL_BUDGET`` and + the orchestrator moves on to the next item. + +Step-log design +=============== + +``AgentResult.steps`` is a HIGH-LEVEL log — one entry per stage transition +(planner-start, generator-start, evaluator-result, critic-result, item-done). +The detailed per-tool steps emitted by each sub-agent stay on the sub-agent's +own result dataclass (``PlannerResult.steps``, ``EntityGenResult.steps``, …) +and are NOT merged into the orchestrator's ``steps`` field. This keeps the +top-level log readable in the UI; the persistence layer can attach sub-agent +step lists separately when needed. +""" + +import concurrent.futures +import threading +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import EvalReport, SourceModel +from agents.agent_mapping_pge.coverage import ( + build_abstract_union_mapping, + classify, + concrete_leaf_descendants, + full_entity_order, + full_relationship_order, + synthetic_endpoint_mapping, +) +from agents.agent_mapping_pge.evaluator.critic import run_critic +from agents.agent_mapping_pge.evaluator.deterministic import ( + evaluate_entity_mapping, + evaluate_relationship_mapping, +) +from agents.agent_mapping_pge.generators.entity import run_entity_generator +from agents.agent_mapping_pge.generators.relationship import ( + run_relationship_generator, +) +from agents.agent_mapping_pge.planner import run_planner +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +# Per-item retry budget for the Generator->Evaluator inner loop. +_PER_ITEM_GENERATOR_ATTEMPTS = 4 +# Global cap on Planner re-invocations triggered by escalated failures. +_PLANNER_REINVOCATION_BUDGET = 3 + +# Bounded parallelism for per-item Generator->Evaluator work. 
Items are mutually +# independent and the SQL client is connection-pooled + thread-safe, so this is +# the single biggest wall-clock win. Kept modest to respect FM-endpoint rate +# limits (call_llm_with_retry handles any 429 backoff). +_MAX_CONCURRENCY = 4 + + +# ===================================================== +# Public dataclasses — mirror the prior mapping agent's shapes +# ===================================================== + + +@dataclass +class AgentStep: + """One observable step of the orchestrator's execution. + + Same shape as :class:`agents.engine_base.AgentStep` plus a few extra + ``step_type`` values used by the PGE orchestrator: + + * ``"planner"`` / ``"generator"`` / ``"evaluator"`` / ``"critic"`` for + stage transitions; the legacy ``"tool_call"`` / ``"tool_result"`` / + ``"output"`` types remain valid so this struct is fully drop-in- + compatible with the prior orchestrator. + """ + + step_type: str + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class AgentResult: + """Outcome of a full PGE orchestration run. + + The first eight fields mirror the prior in-house mapping-agent's result + dataclass exactly so callers can swap engines without touching their + downstream code. The last three are PGE-specific extras the caller + can choose to persist. 
+ """ + + success: bool + entity_mappings: list = field(default_factory=list) + relationship_mappings: list = field(default_factory=list) + steps: List[AgentStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + stats: Dict[str, int] = field(default_factory=dict) + # PGE-specific extras + source_model: Optional[dict] = None + mapping_evaluations: Dict[str, dict] = field(default_factory=dict) + mapping_run_log: List[dict] = field(default_factory=list) + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _ontology_index(ontology: dict) -> Dict[str, dict]: + """Build ``uri -> entity dict`` for fast lookup by URI.""" + out: Dict[str, dict] = {} + for e in (ontology or {}).get("entities", []) or []: + uri = e.get("uri") or e.get("name") + if uri: + out[uri] = e + return out + + +def _relationship_index(ontology: dict) -> Dict[str, dict]: + """Build ``uri -> relationship dict`` for fast lookup by URI.""" + out: Dict[str, dict] = {} + for r in (ontology or {}).get("relationships", []) or []: + uri = r.get("uri") or r.get("name") + if uri: + out[uri] = r + return out + + +def _slice_for_entity(source_model: SourceModel, class_uri: str) -> dict: + """Render the SourceModel slice consumed by the EntityGenerator. + + The slice surfaces only what's relevant to one ontology class: + candidate tables, the canonical-ID per chosen table, and any joins + naming a candidate table on at least one side. 
+ """ + candidate_tables: List[dict] = [] + candidate_table_names: set = set() + for role in source_model.table_roles: + for cand in role.ontology_class_candidates: + if cand.uri == class_uri: + candidate_tables.append( + { + "table": role.table, + "confidence": cand.confidence, + "reason": cand.reason, + } + ) + candidate_table_names.add(role.table) + break # one entry per role is enough + + canonical_id_obj: Dict[str, Any] = { + "ontology_class": class_uri, + "canonical_column_per_table": {}, + "format_note": "", + } + for c in source_model.canonical_ids: + if c.ontology_class == class_uri: + canonical_id_obj = c.to_dict() + break + + relevant_joins: List[dict] = [] + for j in source_model.join_keys: + from_table = j.from_ref.split(".")[0] if j.from_ref else "" + to_table = j.to_ref.split(".")[0] if j.to_ref else "" + if any( + ft == from_table or ft.endswith("." + from_table) + for ft in candidate_table_names + ) or any( + tt == to_table or tt.endswith("." + to_table) + for tt in candidate_table_names + ): + relevant_joins.append(j.to_dict()) + + return { + "candidate_tables": candidate_tables, + "canonical_id": canonical_id_obj, + "relevant_joins": relevant_joins, + } + + +def _slice_for_relationship( + source_model: SourceModel, + property_uri: str, + source_entity_mapping: dict, + target_entity_mapping: dict, +) -> dict: + """Render the SourceModel slice consumed by the RelationshipGenerator. + + The slice surfaces every join key the Planner produced (the Generator + picks among them), plus the candidate-table list filtered to the + source/target classes when those classes are known. 
+ """ + src_class = (source_entity_mapping or {}).get("ontology_class") or ( + source_entity_mapping or {} + ).get("class_uri", "") + tgt_class = (target_entity_mapping or {}).get("ontology_class") or ( + target_entity_mapping or {} + ).get("class_uri", "") + endpoint_classes = {c for c in (src_class, tgt_class) if c} + + candidate_tables: List[dict] = [] + for role in source_model.table_roles: + for cand in role.ontology_class_candidates: + if not endpoint_classes or cand.uri in endpoint_classes: + candidate_tables.append( + { + "table": role.table, + "ontology_class": cand.uri, + "confidence": cand.confidence, + "reason": cand.reason, + } + ) + break + + relevant_joins = [j.to_dict() for j in source_model.join_keys] + + return { + "property_uri": property_uri, + "relevant_joins": relevant_joins, + "candidate_tables": candidate_tables, + } + + +def _wrap_execute_sql(client: Any) -> Callable[[str], dict]: + """Adapt ``client.execute_query`` to the evaluator's expected shape. + + The deterministic evaluator wants ``{"columns": [...], "rows": [{...}]}`` + with FULL rows. ``client.execute_query`` returns ``List[Dict[str, Any]]`` + — we promote that to the evaluator's shape and derive columns from the + first row. Calling the underlying client directly (rather than the + sampling ``tool_execute_sql``) is load-bearing: the deterministic + evaluator's count-based checks need real values, not stringified ones. 
+ """ + + def _run(sql: str) -> dict: + rows = client.execute_query(sql) or [] + if isinstance(rows, dict) and "rows" in rows: + return rows # client already returns the evaluator's shape + columns: List[str] = [] + if rows and isinstance(rows[0], dict): + columns = list(rows[0].keys()) + return {"columns": columns, "rows": list(rows)} + + return _run + + +def _first_hint(report: EvalReport) -> Optional[str]: + """Return the first failure's hint (or ``None`` when the report has none).""" + for f in report.failures: + if f.hint: + return f.hint + return None + + +def _resolve_endpoint_em( + ref: str, + by_uri: Dict[str, dict], + entity_index: Dict[str, dict], +) -> Optional[dict]: + """Best-effort lookup of an endpoint entity mapping. + + The ontology's ``domain`` / ``range`` may use either the entity's full + URI or its short name. We try direct lookup, then a name-match scan. + """ + if not ref: + return None + if ref in by_uri: + return by_uri[ref] + for uri, ent in entity_index.items(): + if ent.get("name") == ref or ent.get("label") == ref: + if uri in by_uri: + return by_uri[uri] + return None + + +def _ref_to_uri(ref: str, entity_index: Dict[str, dict]) -> str: + """Resolve a domain/range ref (URI, name, or label) to a class URI.""" + if not ref or ref in entity_index: + return ref + for uri, ent in entity_index.items(): + if ent.get("name") == ref or ent.get("label") == ref: + return uri + return ref + + +def _endpoint_em(state: "_RunState", ref: str) -> Optional[dict]: + """Resolve a relationship endpoint to an entity mapping carrying an + id-universe SQL. + + Prefers the real (fully-mapped) entity mapping; falls back to a synthetic + id-universe built from the Planner's ``canonical_ids`` so a relationship is + never skipped merely because its endpoint entity's attribute mapping + failed. Returns ``None`` only when the class is entirely unknown to both + the mapped set and the source model. 
+ """ + real = state.entity_mapping_by_uri.get(ref) or _resolve_endpoint_em( + ref, state.entity_mapping_by_uri, state.entity_index + ) + if real is not None: + return real + uri = _ref_to_uri(ref, state.entity_index) + return synthetic_endpoint_mapping(state.source_model, uri) + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_engine") +def run_agent( + host: str, + token: str, + endpoint_name: str, + client: Any, + metadata: dict, + ontology: dict, + entity_mappings: Optional[list] = None, + relationship_mappings: Optional[list] = None, + documents: Optional[list] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: Optional[int] = None, + *, + skip_semantic_critic: bool = False, +) -> AgentResult: + """Run the PGE mapping orchestrator. + + Drop-in replacement for the prior in-house single-loop mapping agent — + same positional/keyword signature, same :class:`AgentResult` shape. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client exposing ``execute_query(sql)``. + metadata: Imported table metadata to hand to the Planner. + ontology: Ontology dict with ``entities`` and ``relationships``. + entity_mappings: Pre-seeded entity mappings (URI matched -> skipped). + relationship_mappings: Pre-seeded relationship mappings (likewise). + documents: Optional pre-loaded domain documents. + on_step: Optional progress callback ``(msg, pct)``. + max_iterations: Per-item override for the Generator's iteration cap. + Kept for API parity with the legacy engine; ``None`` uses each + sub-agent's default. + skip_semantic_critic: When ``True``, the orchestrator skips the + stage-2 critic and accepts every stage-1 PASS as a final PASS. 
+ Production callers leave this ``False``; tests flip it ``True`` + to avoid LLM calls in the orchestrator's unit tests. + + Returns: + An :class:`AgentResult` with the submitted mappings, a high-level + ``steps`` log, per-item ``mapping_run_log``, and PGE-specific + extras (``source_model``, ``mapping_evaluations``). + """ + # ------------------------------------------------------------------ + # Per-call state lives entirely on this RunState object — no module- + # level mutables, so concurrent calls (and tests) cannot collide. + # ------------------------------------------------------------------ + state = _RunState( + host=host, + token=token, + endpoint_name=endpoint_name, + client=client, + metadata=metadata or {}, + ontology=ontology or {}, + documents=list(documents or []), + on_step=on_step, + max_iterations=max_iterations, + skip_semantic_critic=skip_semantic_critic, + ) + + # Pre-seeded mappings carry over verbatim — we never overwrite a URI the + # caller already mapped. + pre_entity_list = list(entity_mappings or []) + pre_rel_list = list(relationship_mappings or []) + preseeded_entity_uris = { + m.get("ontology_class") or m.get("class_uri") or "" for m in pre_entity_list + } + preseeded_entity_uris.discard("") + preseeded_rel_uris = { + m.get("property") or m.get("property_uri") or "" for m in pre_rel_list + } + preseeded_rel_uris.discard("") + + state.entity_mappings.extend(pre_entity_list) + state.relationship_mappings.extend(pre_rel_list) + for m in pre_entity_list: + uri = m.get("ontology_class") or m.get("class_uri") + if uri: + state.entity_mapping_by_uri[uri] = m + + entities_in_scope = state.ontology.get("entities", []) or [] + relationships_in_scope = state.ontology.get("relationships", []) or [] + + logger.info( + "===== MAPPING-PGE ENGINE START ===== endpoint=%s, entities=%d, " + "relationships=%d, preseeded_entities=%d, preseeded_rels=%d, " + "skip_critic=%s", + endpoint_name, + len(entities_in_scope), + len(relationships_in_scope), + 
len(preseeded_entity_uris), + len(preseeded_rel_uris), + skip_semantic_critic, + ) + + # ------------------------------------------------------------------ + # 1. Planner + # ------------------------------------------------------------------ + state.notify("Planning…", pct=2) + state.add_step("planner", "planner-start") + + t0 = time.time() + try: + planner_result = run_planner( + host=host, + token=token, + endpoint_name=endpoint_name, + client=client, + metadata=state.metadata, + ontology=state.ontology, + documents=state.documents, + on_step=None, + ) + except Exception as exc: # noqa: BLE001 — surface any failure as run failure + logger.error("Planner raised an exception: %s", exc, exc_info=True) + return state.finalise(error=f"planner exception: {exc}") + + planner_ms = int((time.time() - t0) * 1000) + state.add_iterations(planner_result.iterations) + state.accumulate_usage(planner_result.usage) + + if not planner_result.success or planner_result.source_model is None: + state.add_step( + "planner", + f"planner-fail: {planner_result.error}", + duration_ms=planner_ms, + ) + logger.error("===== MAPPING-PGE ENGINE FAILED ===== planner failed") + state.notify("Planner failed — aborting.", pct=10) + return state.finalise( + error=f"planner failed: {planner_result.error or 'no source model'}" + ) + + state.source_model = planner_result.source_model + state.refresh_plan() + state.add_step( + "planner", + f"planner-done: entities={len(state.entity_order)}, " + f"relationships={len(state.relationship_order)}", + duration_ms=planner_ms, + ) + + # ------------------------------------------------------------------ + # 2. Walk the plan — entities first, then relationships. 
+ # ------------------------------------------------------------------ + state.entity_index = _ontology_index(state.ontology) + state.rel_index = _relationship_index(state.ontology) + state.execute_sql_fn = _wrap_execute_sql(client) + state.total_items_planned = len(state.entity_order) + len( + state.relationship_order + ) + + # ------------------------------------------------------------------ + # Entity walk — three phases: + # 1. concrete classes — independent, run in a bounded thread pool; + # 2. abstract superclasses — derived from the concrete mappings, so they + # run AFTER phase 1 (cheap, no LLM, kept sequential); + # pre-seeded classes are recorded inline and never re-mapped. + # Per-item work is independent and the SQL client is connection-pooled and + # thread-safe, so parallelism is the single biggest wall-clock win. + # ------------------------------------------------------------------ + concrete_items: List[Tuple[str, dict]] = [] + for entity_uri in list(state.entity_order): + ontology_class = state.entity_index.get(entity_uri, {"uri": entity_uri}) + label = ontology_class.get("label") or ontology_class.get("name", entity_uri) + if entity_uri in preseeded_entity_uris: + state.mapping_run_log.append( + {"item": entity_uri, "kind": "entity", "attempts": [], + "final_status": "PRESEEDED"} + ) + state.notify(f"Skipping pre-seeded {label}") + state.items_done += 1 + elif entity_uri not in state.abstract_uris: + concrete_items.append((entity_uri, ontology_class)) + + def _entity_runner(item): + uri, oc = item + return _run_entity_item(state, oc) + + for (entity_uri, ontology_class), outcome in _run_items_concurrently( + concrete_items, _entity_runner + ): + _merge_entity_result(state, entity_uri, ontology_class, outcome) + + # Phase 2 — abstract superclasses (depend on concrete mappings being present). 
+ for entity_uri in list(state.entity_order): + if entity_uri in state.abstract_uris and entity_uri not in preseeded_entity_uris: + ontology_class = state.entity_index.get(entity_uri, {"uri": entity_uri}) + outcome = _run_abstract_item(state, ontology_class) + _merge_entity_result(state, entity_uri, ontology_class, outcome) + + # ------------------------------------------------------------------ + # Relationship walk — every relationship is independent once all entity + # id-universes exist, so the whole set runs in the same bounded pool. + # ------------------------------------------------------------------ + rel_items: List[Tuple[str, dict, dict, dict]] = [] + for property_uri in list(state.relationship_order): + prop = state.rel_index.get(property_uri, {"uri": property_uri}) + label = prop.get("label") or prop.get("name", property_uri) + if property_uri in preseeded_rel_uris: + state.mapping_run_log.append( + {"item": property_uri, "kind": "relationship", "attempts": [], + "final_status": "PRESEEDED"} + ) + state.notify(f"Skipping pre-seeded {label}") + state.items_done += 1 + continue + # Coverage is engine-enforced: resolve each endpoint to a full mapping + # or a synthetic id-universe (from canonical_ids) so a relationship is + # never silently skipped for a missing endpoint. 
+ source_em = _endpoint_em(state, prop.get("domain", "") or "") + target_em = _endpoint_em(state, prop.get("range", "") or "") + if source_em is None or target_em is None: + missing = "source" if source_em is None else "target" + state.mapping_run_log.append( + {"item": property_uri, "kind": "relationship", "attempts": [], + "final_status": "FAIL_NO_ENDPOINT"} + ) + state.add_step( + "evaluator", + f"relationship {property_uri}: no {missing} id-universe — cannot attempt", + ) + state.notify(f"Cannot map {label}: {missing} endpoint has no id universe") + state.items_done += 1 + continue + rel_items.append((property_uri, prop, source_em, target_em)) + + def _rel_runner(item): + _uri, prop, source_em, target_em = item + return _run_relationship_item(state, prop, source_em, target_em) + + for (property_uri, prop, _s, _t), outcome in _run_items_concurrently( + rel_items, _rel_runner + ): + _merge_relationship_result(state, property_uri, prop, outcome) + + state.notify("Agent completed!", pct=100) + return state.finalise() + + +# ===================================================== +# Run-scoped mutable state +# ===================================================== + + +@dataclass +class _RunState: + """Encapsulates per-call mutable state — keeps ``run_agent`` re-entrant. + + All counters, mapping lists, and accumulators that need to evolve as the + walk progresses live here so the orchestrator never relies on module- + level globals. This also keeps the per-item helpers (``_run_*_item``) + pure functions of state + item input. 
+ """ + + host: str + token: str + endpoint_name: str + client: Any + metadata: dict + ontology: dict + documents: List[Any] + on_step: Optional[Callable[[str, int], None]] + max_iterations: Optional[int] + skip_semantic_critic: bool + + # Output accumulators + entity_mappings: List[dict] = field(default_factory=list) + relationship_mappings: List[dict] = field(default_factory=list) + entity_mapping_by_uri: Dict[str, dict] = field(default_factory=dict) + mapping_run_log: List[dict] = field(default_factory=list) + mapping_evaluations: Dict[str, dict] = field(default_factory=dict) + steps: List[AgentStep] = field(default_factory=list) + usage: Dict[str, int] = field( + default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0} + ) + iterations: int = 0 + submitted_any: bool = False + + # Plan-derived state — refreshed on (re)plan. + source_model: Optional[SourceModel] = None + entity_order: List[str] = field(default_factory=list) + relationship_order: List[str] = field(default_factory=list) + skip_reasons: Dict[str, str] = field(default_factory=dict) + abstract_uris: set = field(default_factory=set) + planner_reinvocations: int = 0 + + # Walk progress + items_done: int = 0 + total_items_planned: int = 0 + + # Per-run caches & lookups + id_universe_cache: Dict[str, set] = field(default_factory=dict) + entity_index: Dict[str, dict] = field(default_factory=dict) + rel_index: Dict[str, dict] = field(default_factory=dict) + execute_sql_fn: Optional[Callable[[str], dict]] = None + + # Guards the shared accumulators (steps/usage/iterations) that per-item + # runners touch while the entity/relationship walks run them in a pool. 
+ _lock: Any = field(default_factory=threading.Lock, repr=False, compare=False) + _replan_lock: Any = field( + default_factory=threading.Lock, repr=False, compare=False + ) + _max_pct: int = 0 + + # -- helpers ---------------------------------------------------------- + + def add_step( + self, + step_type: str, + content: str, + *, + tool_name: str = "", + duration_ms: int = 0, + ) -> None: + with self._lock: + self.steps.append( + AgentStep( + step_type=step_type, + content=content, + tool_name=tool_name, + duration_ms=duration_ms, + ) + ) + + def pct(self) -> int: + total = max(self.total_items_planned, 1) + return min(5 + int((self.items_done / total) * 90), 95) + + def notify(self, msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else self.pct() + # Progress is reported from a thread pool, so clamp to a monotonic + # high-water mark — the bar never visually goes backwards. + with self._lock: + if actual_pct < self._max_pct: + actual_pct = self._max_pct + else: + self._max_pct = actual_pct + logger.info("PGE STEP [%d%%] %s", actual_pct, msg) + if self.on_step: + self.on_step(msg, actual_pct) + + def add_iterations(self, n: int) -> None: + with self._lock: + self.iterations += int(n or 0) + + def accumulate_usage(self, src: Dict[str, int]) -> None: + with self._lock: + for k in ("prompt_tokens", "completion_tokens"): + self.usage[k] = self.usage.get(k, 0) + int((src or {}).get(k, 0)) + + def refresh_plan(self) -> None: + sm = self.source_model + if sm is None: + return + # Coverage is engine-enforced, NOT LLM-discretionary: attempt EVERY + # ontology entity + relationship regardless of what the Planner put in + # mapping_plan.skip. The Planner's order is used only as a hint. 
+ self.entity_order = full_entity_order(self.ontology, sm) + self.relationship_order = full_relationship_order( + self.ontology, self.entity_order, sm + ) + # The Planner may still flag items it judged unmappable; we keep the + # reasons for logging but DO NOT let them remove items from coverage. + self.skip_reasons = {s.item: s.reason for s in sm.mapping_plan.skip} + concrete, abstract = classify(self.ontology, sm) + self.abstract_uris = abstract + logger.info( + "refresh_plan: full coverage — entities=%d (abstract=%d), " + "relationships=%d; planner_skip(advisory)=%d", + len(self.entity_order), + len(abstract), + len(self.relationship_order), + len(self.skip_reasons), + ) + + def replan_once(self) -> bool: + """Re-invoke the Planner once (subject to the global budget). + + Returns ``True`` on success (and updates the plan in place), ``False`` + when the budget is exhausted or the new Planner run failed. + + Serialised by ``_replan_lock`` so concurrent bubbling items (the + entity/relationship walks run in a thread pool) cannot double-invoke + the Planner or corrupt the shared plan state mid-walk. 
+ """ + with self._replan_lock: + return self._replan_once_locked() + + def _replan_once_locked(self) -> bool: + if self.planner_reinvocations >= _PLANNER_REINVOCATION_BUDGET: + return False + self.planner_reinvocations += 1 + self.notify("Re-planning (escalated)…", pct=self.pct()) + self.add_step( + "planner", + f"replan-start (reinvocation #{self.planner_reinvocations})", + ) + t_rp = time.time() + try: + new_result = run_planner( + host=self.host, + token=self.token, + endpoint_name=self.endpoint_name, + client=self.client, + metadata=self.metadata, + ontology=self.ontology, + documents=self.documents, + on_step=None, + ) + except Exception as exc: # noqa: BLE001 + logger.error("Replan raised an exception: %s", exc, exc_info=True) + self.add_step("planner", f"replan-exception: {exc}") + return False + replan_ms = int((time.time() - t_rp) * 1000) + self.add_iterations(new_result.iterations) + self.accumulate_usage(new_result.usage) + if not new_result.success or new_result.source_model is None: + self.add_step( + "planner", + f"replan-fail: {new_result.error}", + duration_ms=replan_ms, + ) + return False + self.source_model = new_result.source_model + self.refresh_plan() + self.add_step("planner", "replan-done", duration_ms=replan_ms) + return True + + def finalise(self, *, error: str = "") -> AgentResult: + """Build the final :class:`AgentResult`.""" + result = AgentResult(success=False) + result.entity_mappings = list(self.entity_mappings) + result.relationship_mappings = list(self.relationship_mappings) + result.steps = list(self.steps) + result.iterations = self.iterations + result.usage = dict(self.usage) + result.mapping_run_log = list(self.mapping_run_log) + result.mapping_evaluations = dict(self.mapping_evaluations) + result.source_model = ( + self.source_model.to_dict() if self.source_model is not None else None + ) + result.stats = { + "total": len(self.entity_order) + len(self.relationship_order), + "entities": len(self.entity_mappings), + 
"relationships": len(self.relationship_mappings), + "planner_reinvocations": self.planner_reinvocations, + } + if error: + result.error = error + result.success = False + return result + + # Success when at least one mapping was submitted, OR when there was + # nothing to map (legitimate empty run). + nothing_to_map = ( + not self.entity_order and not self.relationship_order + ) + result.success = self.submitted_any or nothing_to_map + if not result.success: + result.error = ( + "no mappings submitted (all items failed or were skipped)" + ) + logger.info( + "===== MAPPING-PGE ENGINE COMPLETE ===== success=%s, entities=%d, " + "relationships=%d, iterations=%d, replans=%d", + result.success, + len(self.entity_mappings), + len(self.relationship_mappings), + self.iterations, + self.planner_reinvocations, + ) + return result + + +# ===================================================== +# Per-item walk helpers +# ===================================================== + + +_Outcome = Tuple[str, List[dict], Optional[dict], Optional[EvalReport]] + + +def _run_items_concurrently(items: List[Any], runner: Callable[[Any], _Outcome]): + """Run ``runner(item)`` for each item in a bounded thread pool and yield + ``(item, outcome)`` pairs in the ORIGINAL item order. + + The per-item runners mutate only thread-safe parts of the shared state + (lock-guarded usage/iteration/step accumulators); all result MERGING into + the run accumulators happens in the caller's thread after each future + resolves, so there is no race on the mapping lists/dicts. 
+ """ + if not items: + return + workers = min(_MAX_CONCURRENCY, len(items)) + if workers <= 1: + for item in items: + yield item, runner(item) + return + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool: + future_to_item = {pool.submit(runner, item): item for item in items} + results: Dict[int, _Outcome] = {} + index = {id(item): i for i, item in enumerate(items)} + for fut in concurrent.futures.as_completed(future_to_item): + item = future_to_item[fut] + try: + results[index[id(item)]] = fut.result() + except Exception as exc: # noqa: BLE001 — surface as a failed item + logger.error("Concurrent item runner raised: %s", exc, exc_info=True) + results[index[id(item)]] = ( + "FAIL_BUDGET", + [{"attempt": 1, "stage1_status": "skipped", + "critic_status": "skipped", "bubble": False, + "hint": None, "error": f"runner exception: {exc}"}], + None, + None, + ) + for i, item in enumerate(items): + yield item, results[i] + + +def _merge_entity_result( + state: "_RunState", entity_uri: str, ontology_class: dict, outcome: _Outcome +) -> None: + """Merge one entity item's outcome into the run accumulators (main thread).""" + final_status, attempts_log, last_mapping, last_report = outcome + label = ontology_class.get("label") or ontology_class.get("name", entity_uri) + state.mapping_run_log.append( + {"item": entity_uri, "kind": "entity", "attempts": attempts_log, + "final_status": final_status} + ) + if final_status == "PASS" and last_mapping is not None: + state.entity_mappings.append(last_mapping) + state.entity_mapping_by_uri[entity_uri] = last_mapping + state.submitted_any = True + if last_report is not None: + state.mapping_evaluations[entity_uri] = last_report.to_dict() + state.notify(f"Mapped {label}") + state.items_done += 1 + + +def _merge_relationship_result( + state: "_RunState", property_uri: str, prop: dict, outcome: _Outcome +) -> None: + """Merge one relationship item's outcome into the run accumulators.""" + final_status, attempts_log, 
last_mapping, last_report = outcome + label = prop.get("label") or prop.get("name", property_uri) + state.mapping_run_log.append( + {"item": property_uri, "kind": "relationship", "attempts": attempts_log, + "final_status": final_status} + ) + if final_status == "PASS" and last_mapping is not None: + state.relationship_mappings.append(last_mapping) + state.submitted_any = True + if last_report is not None: + state.mapping_evaluations[property_uri] = last_report.to_dict() + state.notify(f"Mapped {label}") + state.items_done += 1 + + +def _run_abstract_item( + state: "_RunState", + ontology_class: dict, +) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]: + """Derive an abstract superclass mapping as the UNION of its concrete + subclass mappings — no LLM call. + + The abstract class (e.g. ``Clinicalencounter``) has no source table of its + own; its instances are exactly the union of its concrete leaf subclasses, + which have already been mapped earlier in the entity walk. Reusing their + verbatim SQL makes the abstract id-universe identical to the union of the + parts, so relationships whose domain/range is the abstract join with zero + dangling. We still run the deterministic evaluator (cheap) to guarantee + unique, non-null ids; the semantic critic is skipped (a mechanical union + has no column-choice ambiguity to audit). 
+ """ + class_uri = ontology_class.get("uri", "") + label = ontology_class.get("label") or ontology_class.get("name", class_uri) + + concrete, _abstract = classify(state.ontology, state.source_model) + leaf_uris = concrete_leaf_descendants(class_uri, state.ontology, concrete) + sub_ems = [ + state.entity_mapping_by_uri[u] + for u in leaf_uris + if u in state.entity_mapping_by_uri + ] + state.notify( + f"Deriving {label} as UNION of {len(sub_ems)}/{len(leaf_uris)} " + "concrete subclass(es)…", + pct=state.pct(), + ) + state.add_step( + "generator", + f"abstract-derive: {class_uri} from {len(sub_ems)} mapped subclasses", + ) + + mapping = build_abstract_union_mapping(class_uri, ontology_class, sub_ems) + if mapping is None: + attempts_log = [ + { + "attempt": 1, + "generator_ms": 0, + "stage1_status": "skipped", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": "no mapped concrete subclasses to union", + } + ] + return "FAIL_BUDGET", attempts_log, None, None + + t_e = time.time() + try: + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=ontology_class, + execute_sql_fn=state.execute_sql_fn, + ) + except Exception as exc: # noqa: BLE001 + logger.error("Abstract derive eval raised on %s: %s", class_uri, exc) + attempts_log = [ + { + "attempt": 1, + "generator_ms": 0, + "stage1_status": "skipped", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": f"eval exception: {exc}", + } + ] + return "FAIL_BUDGET", attempts_log, mapping, None + eval_ms = int((time.time() - t_e) * 1000) + attempts_log = [ + { + "attempt": 1, + "generator_ms": 0, + "stage1_status": report.status, + "critic_status": "skipped", + "bubble": False, + "hint": _first_hint(report), + } + ] + state.add_step( + "evaluator", + f"abstract-stage1: {class_uri} status={report.status}", + duration_ms=eval_ms, + ) + if report.status == "PASS": + return "PASS", attempts_log, mapping, report + return "FAIL_BUDGET", attempts_log, mapping, report + 
+ +def _run_entity_item( + state: _RunState, + ontology_class: dict, +) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]: + """Run the G->E loop for one entity. + + Returns ``(final_status, attempts_log, last_mapping, last_report)``. + The outer ``while True`` lets a successful replan restart the inner + retry budget fresh, which is the intent of the bubble-to-planner path. + """ + class_uri = ontology_class.get("uri", "") + class_label = ontology_class.get("label") or ontology_class.get( + "name", class_uri + ) + attempts_log: List[dict] = [] + last_mapping: Optional[dict] = None + last_report: Optional[EvalReport] = None + + while True: + retry_hint: Optional[str] = None + bubble_requested = False + for attempt_idx in range(_PER_ITEM_GENERATOR_ATTEMPTS): + attempt_num = attempt_idx + 1 + slice_dict = _slice_for_entity(state.source_model, class_uri) + state.notify( + f"Mapping {class_label} (attempt {attempt_num}/{_PER_ITEM_GENERATOR_ATTEMPTS})…", + pct=state.pct(), + ) + state.add_step( + "generator", + f"entity-gen-start: {class_uri} attempt {attempt_num}", + ) + t_g = time.time() + try: + gen_result = run_entity_generator( + host=state.host, + token=state.token, + endpoint_name=state.endpoint_name, + client=state.client, + ontology_class=ontology_class, + source_model_slice=slice_dict, + retry_hint=retry_hint, + on_step=None, + **( + {"max_iterations": state.max_iterations} + if state.max_iterations is not None + else {} + ), + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "EntityGenerator raised on %s attempt %d: %s", + class_uri, + attempt_num, + exc, + exc_info=True, + ) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": int((time.time() - t_g) * 1000), + "stage1_status": "skipped", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": f"generator exception: {exc}", + } + ) + continue + gen_ms = int((time.time() - t_g) * 1000) + state.add_iterations(gen_result.iterations) + 
state.accumulate_usage(gen_result.usage) + + if not gen_result.success or gen_result.mapping is None: + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "skipped", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": gen_result.error or "generator failed", + } + ) + state.add_step( + "generator", + f"entity-gen-fail: {class_uri} attempt {attempt_num}: " + f"{gen_result.error}", + duration_ms=gen_ms, + ) + retry_hint = gen_result.error or retry_hint + continue + + mapping = gen_result.mapping + last_mapping = mapping + + state.notify(f"Evaluating {class_label}…", pct=state.pct()) + t_e = time.time() + stage1_report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=ontology_class, + execute_sql_fn=state.execute_sql_fn, + ) + eval_ms = int((time.time() - t_e) * 1000) + last_report = stage1_report + state.add_step( + "evaluator", + f"entity-stage1: {class_uri} status={stage1_report.status} " + f"bubble={stage1_report.bubble_to_planner}", + duration_ms=eval_ms, + ) + + if stage1_report.status == "FAIL": + hint = _first_hint(stage1_report) + bubble = bool(stage1_report.bubble_to_planner) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "FAIL", + "critic_status": "skipped", + "bubble": bubble, + "hint": hint, + } + ) + if bubble: + bubble_requested = True + break + retry_hint = hint or retry_hint + continue + + # Stage 1 PASS — optionally run the critic. 
+ if state.skip_semantic_critic: + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "skipped", + "bubble": False, + "hint": None, + } + ) + return "PASS", attempts_log, mapping, stage1_report + + state.notify(f"Critiquing {class_label}…", pct=state.pct()) + t_c = time.time() + try: + critic_result = run_critic( + host=state.host, + token=state.token, + endpoint_name=state.endpoint_name, + client=state.client, + item_kind="entity", + item_uri=class_uri, + item_definition=ontology_class, + submitted_mapping=mapping, + source_model_slice=slice_dict, + stage1_metrics=dict(stage1_report.metrics), + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Critic raised on %s attempt %d: %s", + class_uri, + attempt_num, + exc, + exc_info=True, + ) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": f"critic exception: {exc}", + } + ) + return "PASS", attempts_log, mapping, stage1_report + critic_ms = int((time.time() - t_c) * 1000) + state.add_iterations(critic_result.iterations) + state.accumulate_usage(critic_result.usage) + + critic_report = critic_result.report + state.add_step( + "critic", + f"entity-critic: {class_uri} status=" + f"{critic_report.status if critic_report else '?'} " + f"bubble=" + f"{critic_report.bubble_to_planner if critic_report else '?'}", + duration_ms=critic_ms, + ) + + if not critic_result.success or critic_report is None: + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": critic_result.error or "critic failed", + } + ) + return "PASS", attempts_log, mapping, stage1_report + + if critic_report.status == "PASS": + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": 
"PASS", + "critic_status": "PASS", + "bubble": False, + "hint": None, + } + ) + return "PASS", attempts_log, mapping, critic_report + + hint = _first_hint(critic_report) + bubble = bool(critic_report.bubble_to_planner) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "FAIL", + "bubble": bubble, + "hint": hint, + } + ) + last_report = critic_report + if bubble: + bubble_requested = True + break + retry_hint = hint or retry_hint + continue + + if bubble_requested: + if state.replan_once(): + continue # restart the item with the new plan + return "FAIL_BUBBLE", attempts_log, last_mapping, last_report + return "FAIL_BUDGET", attempts_log, last_mapping, last_report + + +def _run_relationship_item( + state: _RunState, + ontology_property: dict, + source_em: dict, + target_em: dict, +) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]: + """Run the G->E loop for one relationship. + + Returns ``(final_status, attempts_log, last_mapping, last_report)``. 
+ """ + property_uri = ontology_property.get("uri", "") + property_label = ontology_property.get("label") or ontology_property.get( + "name", property_uri + ) + attempts_log: List[dict] = [] + last_mapping: Optional[dict] = None + last_report: Optional[EvalReport] = None + + while True: + retry_hint: Optional[str] = None + bubble_requested = False + for attempt_idx in range(_PER_ITEM_GENERATOR_ATTEMPTS): + attempt_num = attempt_idx + 1 + slice_dict = _slice_for_relationship( + state.source_model, + property_uri, + source_em, + target_em, + ) + state.notify( + f"Mapping {property_label} (attempt {attempt_num}/" + f"{_PER_ITEM_GENERATOR_ATTEMPTS})…", + pct=state.pct(), + ) + state.add_step( + "generator", + f"rel-gen-start: {property_uri} attempt {attempt_num}", + ) + t_g = time.time() + try: + gen_result = run_relationship_generator( + host=state.host, + token=state.token, + endpoint_name=state.endpoint_name, + client=state.client, + ontology_property=ontology_property, + source_entity_mapping=source_em, + target_entity_mapping=target_em, + source_model_slice=slice_dict, + retry_hint=retry_hint, + on_step=None, + **( + {"max_iterations": state.max_iterations} + if state.max_iterations is not None + else {} + ), + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "RelationshipGenerator raised on %s attempt %d: %s", + property_uri, + attempt_num, + exc, + exc_info=True, + ) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": int((time.time() - t_g) * 1000), + "stage1_status": "skipped", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": f"generator exception: {exc}", + } + ) + continue + gen_ms = int((time.time() - t_g) * 1000) + state.add_iterations(gen_result.iterations) + state.accumulate_usage(gen_result.usage) + + if not gen_result.success or gen_result.mapping is None: + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "skipped", + "critic_status": "skipped", + 
"bubble": False, + "hint": None, + "error": gen_result.error or "generator failed", + } + ) + state.add_step( + "generator", + f"rel-gen-fail: {property_uri} attempt {attempt_num}: " + f"{gen_result.error}", + duration_ms=gen_ms, + ) + retry_hint = gen_result.error or retry_hint + continue + + mapping = gen_result.mapping + last_mapping = mapping + + state.notify(f"Evaluating {property_label}…", pct=state.pct()) + t_e = time.time() + stage1_report = evaluate_relationship_mapping( + mapping=mapping, + source_entity_mapping=source_em, + target_entity_mapping=target_em, + execute_sql_fn=state.execute_sql_fn, + id_universe_cache=state.id_universe_cache, + ) + eval_ms = int((time.time() - t_e) * 1000) + last_report = stage1_report + state.add_step( + "evaluator", + f"rel-stage1: {property_uri} status={stage1_report.status} " + f"bubble={stage1_report.bubble_to_planner}", + duration_ms=eval_ms, + ) + + if stage1_report.status == "FAIL": + hint = _first_hint(stage1_report) + bubble = bool(stage1_report.bubble_to_planner) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "FAIL", + "critic_status": "skipped", + "bubble": bubble, + "hint": hint, + } + ) + if bubble: + bubble_requested = True + break + retry_hint = hint or retry_hint + continue + + if state.skip_semantic_critic: + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "skipped", + "bubble": False, + "hint": None, + } + ) + return "PASS", attempts_log, mapping, stage1_report + + state.notify(f"Critiquing {property_label}…", pct=state.pct()) + t_c = time.time() + try: + critic_result = run_critic( + host=state.host, + token=state.token, + endpoint_name=state.endpoint_name, + client=state.client, + item_kind="relationship", + item_uri=property_uri, + item_definition=ontology_property, + submitted_mapping=mapping, + source_model_slice=slice_dict, + stage1_metrics=dict(stage1_report.metrics), + ) + 
except Exception as exc: # noqa: BLE001 + logger.error( + "Critic raised on %s attempt %d: %s", + property_uri, + attempt_num, + exc, + exc_info=True, + ) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": f"critic exception: {exc}", + } + ) + return "PASS", attempts_log, mapping, stage1_report + critic_ms = int((time.time() - t_c) * 1000) + state.add_iterations(critic_result.iterations) + state.accumulate_usage(critic_result.usage) + + critic_report = critic_result.report + state.add_step( + "critic", + f"rel-critic: {property_uri} status=" + f"{critic_report.status if critic_report else '?'} " + f"bubble=" + f"{critic_report.bubble_to_planner if critic_report else '?'}", + duration_ms=critic_ms, + ) + + if not critic_result.success or critic_report is None: + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "skipped", + "bubble": False, + "hint": None, + "error": critic_result.error or "critic failed", + } + ) + return "PASS", attempts_log, mapping, stage1_report + + if critic_report.status == "PASS": + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "PASS", + "bubble": False, + "hint": None, + } + ) + return "PASS", attempts_log, mapping, critic_report + + hint = _first_hint(critic_report) + bubble = bool(critic_report.bubble_to_planner) + attempts_log.append( + { + "attempt": attempt_num, + "generator_ms": gen_ms, + "stage1_status": "PASS", + "critic_status": "FAIL", + "bubble": bubble, + "hint": hint, + } + ) + last_report = critic_report + if bubble: + bubble_requested = True + break + retry_hint = hint or retry_hint + continue + + if bubble_requested: + if state.replan_once(): + continue + return "FAIL_BUBBLE", attempts_log, last_mapping, last_report + return "FAIL_BUDGET", 
attempts_log, last_mapping, last_report diff --git a/src/agents/agent_mapping_pge/evaluator/__init__.py b/src/agents/agent_mapping_pge/evaluator/__init__.py new file mode 100644 index 00000000..41e7ef5e --- /dev/null +++ b/src/agents/agent_mapping_pge/evaluator/__init__.py @@ -0,0 +1,18 @@ +"""Evaluator stage of the mapping PGE pipeline. + +Stage 1 (this module) is the *deterministic* evaluator — pure-Python checks +backed by SQL counts. Stage 2 (added in a later sprint) is the semantic +evaluator that uses an LLM to judge naming/semantic fidelity. + +The deterministic checks live in :mod:`agents.agent_mapping_pge.evaluator.deterministic`. +""" + +from agents.agent_mapping_pge.evaluator.deterministic import ( + evaluate_entity_mapping, + evaluate_relationship_mapping, +) + +__all__ = [ + "evaluate_entity_mapping", + "evaluate_relationship_mapping", +] diff --git a/src/agents/agent_mapping_pge/evaluator/critic.py b/src/agents/agent_mapping_pge/evaluator/critic.py new file mode 100644 index 00000000..4c7db9a2 --- /dev/null +++ b/src/agents/agent_mapping_pge/evaluator/critic.py @@ -0,0 +1,747 @@ +""" +OntoBricks Mapping-PGE Semantic Critic Agent. + +Sprint 6 of the Planner-Generator-Evaluator (PGE) redesign — stage 2 of the +Evaluator. Runs ONLY after the deterministic (stage-1) evaluator has passed. + +The Critic audits ONE submitted mapping for SEMANTIC correctness — issues that +pure structural checks cannot catch: + +* the WRONG TABLE was picked (e.g. ``antenatal_visits`` chosen to realise + the ``Delivery`` class), or +* the wrong COLUMN within the right table (e.g. ``appointment_date`` used + for ``deliveryDate``). + +The Critic's "bubble" signal is sharp: if the wrong TABLE was chosen, the +verdict bubbles to the Planner (which must revise the source model); if just +a wrong column inside the right table, the verdict stays with the Generator +which can retry against the same table. 
+ +The loop shape mirrors :mod:`agents.agent_mapping_pge.generators.entity` — +same ``call_serving_endpoint`` + ``dispatch_tool`` ReAct cycle, a 1-second +inter-iteration delay, same MLflow trace decorator. Differences: + +* Smaller default budget (6) — auditing is bounded work; if the Critic can't + conclude in 6 iterations, it defers (PASS with a reasoning note) rather + than falsely escalates. +* Different tool set: only ``sample_table``, ``execute_sql``, + ``get_documents_context``, and the terminal ``submit_evaluation``. +* No single-shot fallback — the Critic produces structured output through + ``submit_evaluation`` only. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import requests + +if TYPE_CHECKING: + from agents.agent_mapping_pge.contracts import EvalReport + +from back.core.logging import get_logger +from agents.engine_base import ( + accumulate_usage, + call_serving_endpoint, + dispatch_tool, +) +from agents.tools.context import ToolContext +from agents.tools.documents import ( + GET_DOCUMENTS_CONTEXT_DEF, + tool_get_documents_context, +) +from agents.tools.evaluation import ( + EVALUATION_TOOL_DEFINITIONS, + EVALUATION_TOOL_HANDLERS, +) +from agents.tools.planner import ( + SAMPLE_TABLE_DEF, + tool_sample_table, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 6 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 +# See planner._MAX_TOKENS comment — same rationale for submit_evaluation. +_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_critic" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The Critic only needs: +# * sample_table – peek at actual values to verify the column 
+# * execute_sql – targeted probes for "is this column really +# what it claims" sanity checks. +# * get_documents_context – consult any imported domain glossary. +# * submit_evaluation – TERMINAL. +# +# We deliberately exclude: +# * get_metadata / get_ontology — the audit target is supplied in the user +# prompt; broad re-fetches just inflate context. +# * column_value_overlap / distinct_count — those are structural, already +# covered by the deterministic stage. +# * submit_source_model / submit_entity_mapping / submit_relationship_mapping +# — wrong stage. + +TOOL_DEFINITIONS: List[dict] = ( + [SAMPLE_TABLE_DEF, GET_DOCUMENTS_CONTEXT_DEF] + + SQL_TOOL_DEFINITIONS + + EVALUATION_TOOL_DEFINITIONS +) + +TOOL_HANDLERS: Dict[str, Callable] = { + "sample_table": tool_sample_table, + "get_documents_context": tool_get_documents_context, + **SQL_TOOL_HANDLERS, + **EVALUATION_TOOL_HANDLERS, +} + + +# ===================================================== +# Data classes +# ===================================================== + + +@dataclass +class CriticStep: + """One observable step of the Critic's execution. + + Mirrors :class:`agents.agent_mapping_pge.generators.entity.EntityGenStep` + so the orchestrator (Sprint 7) can render a per-audit timeline in the UI. + """ + + step_type: str # "tool_call" | "tool_result" | "output" + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class CriticResult: + """Outcome of a single Critic invocation. + + ``report`` is populated when the agent terminated by submitting a verdict + via ``submit_evaluation``. ``success`` here is the agent-level success + flag — it indicates a *clean termination*, NOT a PASS verdict. A FAIL + verdict with ``bubble_to_planner=True`` still has ``success=True``. + ``success=False`` is reserved for budget exhaustion, text-only output, + and LLM/transport errors. 
+ """ + + success: bool + report: Optional["EvalReport"] = None + steps: List[CriticStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + + +# ===================================================== +# System prompt +# ===================================================== +# +# Kept under 3KB. Frames the Critic as a senior data engineer auditing ONE +# submitted mapping for SEMANTIC correctness — the structural checks have +# already passed. The decision rubric (PASS / FAIL+no-bubble / FAIL+bubble) +# is load-bearing: it determines whether the orchestrator retries the +# Generator or re-invokes the Planner. + +SYSTEM_PROMPT = """\ +You are a senior data engineer auditing ONE submitted mapping for SEMANTIC \ +correctness. The structural checks (row counts, distinct IDs, dangling FKs) \ +have ALREADY PASSED — your job is to catch wrong-concept errors that pure \ +structural checks cannot see. + +WHAT YOU AUDIT +• Did the mapping pick the RIGHT TABLE for the ontology class/property? +• Do sampled values in the chosen column(s) actually represent what the \ +ontology attribute means? (e.g. "delivery_date" should be a delivery date, \ +not a booking date.) +• Does the column's semantics match the ontology comment / label? + +TOOLS +You have these tools: + • sample_table – Up to N random rows from a table. Use to peek at \ +actual values and check they match the concept. + • execute_sql – Targeted SQL for "is this column really what it \ +claims" probes (e.g. value ranges, distinct categories, null patterns). + • get_documents_context – Imported domain glossaries / data dictionaries. \ +Check against these when the column's role is non-obvious. + • submit_evaluation – TERMINAL. Call EXACTLY ONCE when you have a \ +confident verdict. + +DECISION RUBRIC +• PASS — sampled values, column semantics, and domain context all support \ +the mapping. status="PASS", failures=[], bubble_to_planner=false. 
+• FAIL with bubble_to_planner=false — the WRONG COLUMN was picked within \ +the RIGHT TABLE. The Generator can fix this on retry. Populate failures[] \ +with the specific column-level issue and a concrete hint. +• FAIL with bubble_to_planner=true — the WRONG TABLE was chosen entirely. \ +The Planner must revise the source model. Populate failures[] and set the \ +bubble flag. + +HINT DISCIPLINE +• Hints must be CONCRETE, ACTIONABLE, single-sentence corrections. +• Good column-level hint: "Sampled rows show `appointment_date` is the \ +booking date, not delivery date. Use `delivery_dttm` instead." +• Good table-level hint: "This mapping uses `antenatal_visits`, but the \ +chosen class is Delivery. Switch to the `labour_delivery` table." +• Bad hint (vague): "consider using a different column" +• Bad hint (chatty): "I think there might be an issue here, you should look \ +into it more carefully" + +HARD RULES +• You are bounded by max_iterations=6. Keep your audit FOCUSED — pick the \ +one or two probes that would change your verdict, not exhaustive ones. +• Call submit_evaluation EXACTLY ONCE. +• If you cannot determine a verdict within 6 iterations, submit PASS with a \ +reasoning note explaining the uncertainty. Do NOT bubble — better to defer \ +than to falsely escalate. +• Do not call get_metadata, get_ontology, column_value_overlap, \ +distinct_count, submit_source_model, submit_entity_mapping, or \ +submit_relationship_mapping — they are not available to you. The audit \ +target and structural metrics are already in the user message. 
+""" + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _format_entity_definition(item_definition: dict) -> List[str]: + """Lines for an entity (ontology class) audit target.""" + parts: List[str] = [] + label = item_definition.get("label") or item_definition.get("name", "") + comment = item_definition.get("comment", "") or "" + attributes = item_definition.get("attributes", []) or [] + + parts.append(f" label: {label}") + if comment: + parts.append(f" comment: {comment}") + if attributes: + parts.append(f" attributes ({len(attributes)}):") + for attr in attributes: + if isinstance(attr, dict): + a_name = attr.get("name") or attr.get("label") or attr.get("uri", "?") + a_type = attr.get("type") or attr.get("range") or "" + parts.append(f" - {a_name}" + (f" ({a_type})" if a_type else "")) + else: + parts.append(f" - {attr}") + return parts + + +def _format_relationship_definition(item_definition: dict) -> List[str]: + """Lines for a relationship (ontology property) audit target. + + Always emits explicit ``domain`` and ``range`` lines — these are what + differentiate a relationship audit from an entity audit, and the tests + pin them. 
+ """ + parts: List[str] = [] + label = item_definition.get("label") or item_definition.get("name", "") + comment = item_definition.get("comment", "") or "" + domain = item_definition.get("domain", "") or "" + range_class = item_definition.get("range", "") or "" + + parts.append(f" label: {label}") + if comment: + parts.append(f" comment: {comment}") + parts.append(f" domain: {domain}") + parts.append(f" range: {range_class}") + return parts + + +def _format_submitted_entity_mapping(submitted_mapping: dict) -> List[str]: + """Lines summarising an entity mapping under audit.""" + parts: List[str] = ["SUBMITTED MAPPING (entity)"] + parts.append(f" sql_query: {submitted_mapping.get('sql_query', '')}") + parts.append(f" id_column: {submitted_mapping.get('id_column', '')}") + parts.append(f" label_column: {submitted_mapping.get('label_column', '')}") + attr_map = submitted_mapping.get("attribute_mappings", {}) or {} + if attr_map: + parts.append(" attribute_mappings:") + for k, v in attr_map.items(): + parts.append(f" {k} -> {v}") + unmapped = submitted_mapping.get("unmapped_attributes", []) or [] + if unmapped: + parts.append(" unmapped_attributes:") + for u in unmapped: + if isinstance(u, dict): + parts.append( + f" - {u.get('name', '?')}: {u.get('reason', '')}" + ) + else: + parts.append(f" - {u}") + return parts + + +def _format_submitted_relationship_mapping(submitted_mapping: dict) -> List[str]: + """Lines summarising a relationship mapping under audit.""" + parts: List[str] = ["SUBMITTED MAPPING (relationship)"] + parts.append(f" sql_query: {submitted_mapping.get('sql_query', '')}") + parts.append( + f" source_id_column: {submitted_mapping.get('source_id_column', '')}" + ) + parts.append( + f" target_id_column: {submitted_mapping.get('target_id_column', '')}" + ) + parts.append( + f" source_class: {submitted_mapping.get('source_class', '') or submitted_mapping.get('domain', '')}" + ) + parts.append( + f" target_class: {submitted_mapping.get('target_class', '') or 
submitted_mapping.get('range_class', '')}" + ) + return parts + + +def _build_user_prompt( + item_kind: str, + item_uri: str, + item_definition: dict, + submitted_mapping: dict, + source_model_slice: dict, + stage1_metrics: dict, +) -> str: + """Render the audit user prompt. + + Structure: + 1. AUDIT TARGET — item_kind, URI, ontology metadata (label/comment, + attributes for entities; domain/range for relationships). + 2. SUBMITTED MAPPING — the actual mapping under audit. + 3. PLANNER'S PREDICTION — the slice the Planner curated for this item. + 4. STRUCTURAL CHECK METRICS (PASSED) — context from stage 1. + 5. YOUR TASK — short reminder of the rubric. + """ + parts: List[str] = [] + + parts.append("AUDIT TARGET") + parts.append(f" kind: {item_kind}") + parts.append(f" uri: {item_uri}") + if item_kind == "relationship": + parts.extend(_format_relationship_definition(item_definition or {})) + else: + parts.extend(_format_entity_definition(item_definition or {})) + + parts.append("") + if item_kind == "relationship": + parts.extend(_format_submitted_relationship_mapping(submitted_mapping or {})) + else: + parts.extend(_format_submitted_entity_mapping(submitted_mapping or {})) + + parts.append("") + parts.append("PLANNER'S PREDICTION") + parts.append(json.dumps(source_model_slice or {}, indent=2, default=str)) + + parts.append("") + parts.append("STRUCTURAL CHECK METRICS (PASSED)") + parts.append(json.dumps(stage1_metrics or {}, indent=2, default=str)) + + parts.append("") + parts.append("YOUR TASK") + parts.append( + "Audit the SEMANTIC correctness of the submitted mapping. Use " + "sample_table / execute_sql / get_documents_context as needed, then " + "call submit_evaluation EXACTLY ONCE with your verdict. Follow the " + "PASS / FAIL(no bubble) / FAIL(bubble) rubric in the system prompt." 
+ ) + + prompt = "\n".join(parts) + logger.debug( + "_build_user_prompt for %s=%s (%d chars):\n%s", + item_kind, + item_uri, + len(prompt), + prompt, + ) + return prompt + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_critic") +def run_critic( + host: str, + token: str, + endpoint_name: str, + client: Any, + *, + item_kind: str, + item_uri: str, + item_definition: dict, + submitted_mapping: dict, + source_model_slice: dict, + stage1_metrics: dict, + documents: Optional[list] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: int = MAX_ITERATIONS, +) -> CriticResult: + """Run the Semantic Critic agent for one submitted mapping. + + The Critic autonomously audits ``submitted_mapping`` for semantic + correctness using ``sample_table`` / ``execute_sql`` / + ``get_documents_context``, then submits a verdict via the terminal + ``submit_evaluation`` tool. The resulting :class:`EvalReport` (stage + ``"semantic"``) is stored on ``ctx.semantic_eval_report`` and returned in + ``CriticResult.report``. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client (must expose ``execute_query(sql)``). + item_kind: ``"entity"`` or ``"relationship"``. + item_uri: The ontology class or property URI under audit. + item_definition: Full ontology dict for the item (label/comment, + plus attributes for entities or domain/range for relationships). + submitted_mapping: The mapping under audit (handler dict shape). + source_model_slice: The Planner's slice for this item. + stage1_metrics: Metrics from the deterministic evaluator, for + context. + documents: Optional pre-loaded domain documents — surfaced via + ``get_documents_context``. + on_step: Optional progress callback ``(msg, pct)`` for UI updates. 
+ max_iterations: Upper bound on tool-call iterations (default 6 — + smaller than the Generators because auditing is bounded work). + + Returns: + A :class:`CriticResult`. ``success`` is True iff the Critic + terminated by submitting a verdict; in that case ``report`` holds + the resulting :class:`EvalReport`. On failure (budget exhaustion, + text-only output, transport error), ``error`` explains why. + """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + logger.info( + "===== CRITIC START ===== endpoint=%s, kind=%s, uri=%s, max_iter=%d", + endpoint_name, + item_kind, + item_uri, + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + # The audit target is in the user prompt; metadata/ontology are not + # needed by the Critic's tools. + metadata={}, + ontology={}, + documents=list(documents or []), + ) + + result = CriticResult(success=False) + + user_prompt = _build_user_prompt( + item_kind=item_kind, + item_uri=item_uri, + item_definition=item_definition or {}, + submitted_mapping=submitted_mapping or {}, + source_model_slice=source_model_slice or {}, + stage1_metrics=stage1_metrics or {}, + ) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "Critic conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else 5 + logger.info("CRITIC STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify(f"Auditing {item_kind} {item_uri}…", pct=1) + + # 
------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- Critic iteration %d/%d — %d messages, report=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if ctx.semantic_eval_report is not None else "unset", + ) + notify( + f"Critic iteration {current_iteration}/{iteration_limit}…", + pct=pct, + ) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" 
+ logger.warning( + "Critic iteration %d: HTTPError status=%s", + current_iteration, + status, + ) + logger.debug( + "Critic iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Critic: endpoint refused tools — cannot produce an evaluation" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Critic: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error("Critic: timeout at iteration %d", current_iteration) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Critic: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "Critic iteration %d: LLM responded in %dms", + current_iteration, + elapsed_ms, + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = llm_response.get("choices", [{}])[0] + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message", {}) + tool_calls = message.get("tool_calls", []) + has_content = bool(message.get("content")) + logger.info( + "Critic iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + finish_reason, + len(tool_calls), + has_content, + ) + + if not tool_calls: + # The Critic must terminate via 
submit_evaluation, never via + # free text. Text-only output is a failure. + content = (message.get("content") or "")[:500] + logger.warning( + "Critic iteration %d: produced text without submitting evaluation — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + CriticStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "critic produced text without submitting evaluation" + result.iterations = current_iteration + result.usage = total_usage + notify( + "Critic produced text without submitting evaluation.", + pct=pct, + ) + return result + + logger.info( + "Critic iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) + except json.JSONDecodeError: + arguments = {} + + logger.info( + "Critic iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + if tool_name == "submit_evaluation": + notify( + f"Submitting evaluation for {item_uri}…", pct=pct + ) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = arguments.get("sql", "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + elif tool_name == "get_documents_context": + notify("Retrieving documents…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + CriticStep( + step_type="tool_call", + content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + 
TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "Critic iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + CriticStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_evaluation returned success=True + # AND stamped an EvalReport onto the context. An invalid status + # (the handler returns success=False) does NOT terminate the + # loop — the agent continues so it can resubmit a valid verdict. + if tool_name == "submit_evaluation": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if ( + parsed.get("success") is True + and ctx.semantic_eval_report is not None + ): + terminal_success = True + logger.info( + "Critic iteration %d: submit_evaluation succeeded — terminating", + current_iteration, + ) + + if terminal_success: + result.success = True + result.report = ctx.semantic_eval_report + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== CRITIC COMPLETE ===== uri=%s, status=%s, bubble=%s, " + "iterations=%d, prompt_tokens=%d, completion_tokens=%d", + item_uri, + result.report.status if result.report else "?", + result.report.bubble_to_planner if result.report else "?", + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify(f"Critic verdict submitted for {item_uri}.", pct=100) + return result + + # Budget exhausted without a successful submit. 
+ result.iterations = iteration_limit + result.usage = total_usage + result.error = "critic exhausted iteration budget" + logger.error("===== CRITIC FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/agent_mapping_pge/evaluator/deterministic.py b/src/agents/agent_mapping_pge/evaluator/deterministic.py new file mode 100644 index 00000000..b8e72951 --- /dev/null +++ b/src/agents/agent_mapping_pge/evaluator/deterministic.py @@ -0,0 +1,539 @@ +"""Deterministic (stage-1) evaluator for submitted mappings. + +This module is pure-Python and has no LLM dependency. It runs the +submitted mapping's SQL through a caller-supplied ``execute_sql_fn`` and +checks structural invariants (row count, distinct id count, dangling +foreign-key fractions, etc.). + +``execute_sql_fn`` contract:: + + def execute_sql_fn(sql: str) -> dict +returning ``{"columns": [...], "rows": [{col: value, ...}, ...]}``. + +Important: this is the *full* result set, not the 3-row sample emitted by +:func:`agents.tools.sql.tool_execute_sql`. The orchestrator (Sprint 7) is +responsible for plugging in a runner that returns full rows — typically a +thin wrapper around ``DatabricksClient.execute_query``. + +All checks compute every metric even when some fail; the resulting +:class:`~agents.agent_mapping_pge.contracts.EvalReport` lists every failure +so the Generator/Planner can address them in one shot. +""" + +from typing import Any, Callable, Dict, List, Optional, Tuple + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport +from agents.agent_mapping_pge.evaluator.report import build_report + +logger = get_logger(__name__) + +# Thresholds for stage-1 checks. These are intentionally lax — the +# semantic evaluator (stage 2) catches subtler issues. 
+_DANGLING_FK_FAIL_THRESHOLD = 0.05 +_DANGLING_FK_BUBBLE_THRESHOLD = 0.5 + + +SqlFn = Callable[[str], dict] + + +# ===================================================== +# Helpers +# ===================================================== + + +def _resolve_id_col(mapping: dict, fallback: str = "ID") -> str: + """Return the column name that holds the entity identifier in the row dicts.""" + return mapping.get("id_column") or fallback + + +def _extract_id_values(rows: List[dict], id_col: str) -> List[Any]: + """Pull the id_col value from each row; missing key -> ``None``.""" + return [r.get(id_col) for r in rows] + + +def _attribute_names(ontology_class: dict) -> List[str]: + """Ontology attributes can come in a few shapes; normalise to a list of names.""" + attrs = ontology_class.get("attributes") or [] + out: List[str] = [] + for a in attrs: + if isinstance(a, str): + out.append(a) + elif isinstance(a, dict): + name = a.get("name") or a.get("uri") or a.get("label") + if name: + out.append(name) + return out + + +def _fail( + *, + check: str, + expected: str, + observed: str, + hint: str, + kind: str = "structural", +) -> EvalFailure: + return EvalFailure( + kind=kind, check=check, expected=expected, observed=observed, hint=hint + ) + + +class _SqlExecError(Exception): + """A generated mapping's SQL parsed but failed at execution time. + + Wraps the underlying DB driver exception so the deterministic evaluator + can convert it into an actionable FAIL — never let it crash the run. + """ + + +def _exec(execute_sql_fn: SqlFn, sql: str) -> dict: + """Run SQL, normalising any driver-level failure into ``_SqlExecError``. + + Generated mappings routinely produce SQL that *parses* but fails at + execution (UNION column-type mismatch, invalid CAST, unknown column). + The PGE contract is that such errors become feedback for the generator, + so they must surface as a FAIL report rather than an unhandled exception + that aborts the whole agent run. 
+ """ + try: + return execute_sql_fn(sql) or {} + except Exception as exc: # noqa: BLE001 — any driver error becomes feedback + raise _SqlExecError(str(exc)) from exc + + +def _sql_error_report(*, item: str, sql_error: str) -> EvalReport: + """Build a FAIL report for a mapping whose SQL failed to execute. + + ``bubble_to_planner`` stays False: a runtime SQL error is the + Generator's to fix (align types, correct columns), not a signal that the + Planner's source model is wrong. + """ + # Keep the hint compact — driver errors can be very long. + err = sql_error.strip().splitlines()[0][:300] if sql_error else "unknown error" + return build_report( + stage="deterministic", + metrics={"sql_error": err}, + failures=[ + _fail( + check="sql_execution", + expected="SQL executes without error", + observed="execution error", + hint=( + f"The mapping SQL for '{item}' failed to execute: {err}. " + "Fix the SQL — e.g. align UNION branch column types with " + "explicit CAST (a common cause is one branch typing a " + "column as BIGINT and another as STRING/NULL), correct " + "column names, or use try_cast for malformed values." + ), + ) + ], + bubble_to_planner=False, + ) + + +# ===================================================== +# Entity evaluator +# ===================================================== + + +def evaluate_entity_mapping( + *, + mapping: dict, + ontology_class: dict, + execute_sql_fn: SqlFn, +) -> EvalReport: + """Run the stage-1 deterministic checks on a submitted entity mapping. + + Args: + mapping: Submitted entity mapping in the shape produced by + ``tool_submit_entity_mapping``. + ontology_class: The ontology-class dict the mapping targets; must + expose an ``attributes`` list (each item being a name string or + a dict with a ``name`` key). + execute_sql_fn: Caller-supplied SQL runner — see module docstring. + + Returns: + An :class:`EvalReport` summarising the metrics and any failures. 
+ ``bubble_to_planner`` is set when ``row_count == 0`` (typically + means the mapping is querying the wrong table altogether). + """ + class_name = mapping.get("class_name") or ontology_class.get("name") or "?" + sql = mapping.get("sql_query", "") + id_col = _resolve_id_col(mapping) + logger.info( + "evaluate_entity_mapping: class=%s, id_col=%s, sql_len=%d", + class_name, + id_col, + len(sql), + ) + + try: + result = _exec(execute_sql_fn, sql) + except _SqlExecError as exc: + logger.warning( + "evaluate_entity_mapping: class=%s SQL failed to execute: %s", + class_name, + exc, + ) + return _sql_error_report(item=class_name, sql_error=str(exc)) + rows = result.get("rows", []) or [] + row_count = len(rows) + + id_values = _extract_id_values(rows, id_col) + null_id_count = sum(1 for v in id_values if v is None) + distinct_id_count = len({v for v in id_values if v is not None}) + + raw_unmapped = mapping.get("unmapped_attributes") or [] + declared_unmapped: set = set() + for item in raw_unmapped: + if isinstance(item, dict): + name = item.get("name") + if name: + declared_unmapped.add(str(name)) + elif item is not None: + declared_unmapped.add(str(item)) + declared_mapped = set((mapping.get("attribute_mappings") or {}).keys()) + all_attrs = _attribute_names(ontology_class) + unmapped_attrs = [ + a for a in all_attrs if a not in declared_mapped and a not in declared_unmapped + ] + unmapped_pct = (len(unmapped_attrs) / len(all_attrs)) if all_attrs else 0.0 + + metrics: Dict[str, Any] = { + "row_count": row_count, + "distinct_id_count": distinct_id_count, + "null_id_count": null_id_count, + "unmapped_attribute_pct": unmapped_pct, + "unmapped_attributes": unmapped_attrs, + } + + failures: List[EvalFailure] = [] + bubble = False + + if row_count == 0: + failures.append( + _fail( + check="row_count", + expected="> 0", + observed="0", + hint=( + f"Entity '{class_name}' SQL returned 0 rows. Check the FROM " + "table is correct and the WHERE clause is not over-filtering." 
+ ), + ) + ) + bubble = True + + if row_count > 0 and distinct_id_count != row_count: + dupes = row_count - distinct_id_count + failures.append( + _fail( + check="distinct_id_count", + expected=f"== row_count ({row_count})", + observed=str(distinct_id_count), + hint=( + f"{dupes} duplicate '{id_col}' value(s) in entity '{class_name}'. " + "Add DISTINCT or use a stricter id column." + ), + ) + ) + + if null_id_count > 0: + failures.append( + _fail( + check="null_id_count", + expected="== 0", + observed=str(null_id_count), + hint=( + f"{null_id_count} row(s) have NULL '{id_col}' in entity " + f"'{class_name}'. Add 'WHERE {id_col} IS NOT NULL' to the SQL." + ), + ) + ) + + if unmapped_pct > 0: + failures.append( + _fail( + check="unmapped_attribute_pct", + expected="== 0", + observed=f"{unmapped_pct:.3f}", + hint=( + f"{len(unmapped_attrs)} attribute(s) of '{class_name}' are " + f"neither in attribute_mappings nor declared in " + f"unmapped_attributes: {unmapped_attrs}. Map them, or list " + "them explicitly under 'unmapped_attributes'." + ), + ) + ) + + logger.info( + "evaluate_entity_mapping: class=%s -> %s (%d failure(s), bubble=%s)", + class_name, + "PASS" if not failures else "FAIL", + len(failures), + bubble, + ) + return build_report( + stage="deterministic", + metrics=metrics, + failures=failures, + bubble_to_planner=bubble, + ) + + +# ===================================================== +# Relationship evaluator +# ===================================================== + + +def _distinct_id_set( + entity_mapping: dict, + execute_sql_fn: SqlFn, + id_universe_cache: Optional[Dict[str, set]] = None, +) -> set: + """Materialise the set of valid ids for a given entity mapping. + + When ``id_universe_cache`` is provided it is consulted/populated keyed + by the entity mapping's SQL string, avoiding redundant SQL execution + across repeated calls that share endpoint entities. 
+ """ + sql = entity_mapping.get("sql_query", "") + id_col = _resolve_id_col(entity_mapping) + if id_universe_cache is not None and sql in id_universe_cache: + return id_universe_cache[sql] + result = _exec(execute_sql_fn, sql) # may raise _SqlExecError + rows = result.get("rows", []) or [] + ids = {r.get(id_col) for r in rows if r.get(id_col) is not None} + if id_universe_cache is not None: + id_universe_cache[sql] = ids + return ids + + +def _resolve_edge_columns(mapping: dict) -> Tuple[str, str]: + """Return ``(source_col, target_col)`` for a relationship mapping.""" + return ( + mapping.get("source_id_column") or "source_id", + mapping.get("target_id_column") or "target_id", + ) + + +def evaluate_relationship_mapping( + *, + mapping: dict, + source_entity_mapping: dict, + target_entity_mapping: dict, + execute_sql_fn: SqlFn, + expected_cross_source_overlap_band: Optional[Tuple[float, float]] = None, + id_universe_cache: Optional[Dict[str, set]] = None, +) -> EvalReport: + """Run stage-1 deterministic checks on a relationship mapping. + + Checks: + + * ``total_edges > 0`` + * ``dangling_source_pct < 0.05`` — fraction of source ids that do not + exist in the source entity's id universe. + * ``dangling_target_pct < 0.05`` — same for targets. + * If ``expected_cross_source_overlap_band`` is supplied, the realised + ``overlap_pct`` (fraction of edges whose target id appears in the + target entity universe) must fall inside the band. + + ``bubble_to_planner`` is set when ``total_edges == 0``, when the source + dangling fraction exceeds ``0.5``, or when the target dangling fraction + exceeds ``0.5`` *and* the realised overlap is materially worse than the + Planner predicted (either no band was supplied, or the band check + itself failed). These cases typically indicate the relationship was + built off the wrong join key. + + Args: + id_universe_cache: Optional caller-managed dict mapping an entity + mapping's ``sql_query`` string to its materialised set of ids. 
+ When provided, repeated calls across relationships that share + endpoint entities reuse cached id universes instead of + re-running the entity SQL via ``execute_sql_fn``. When + ``None`` (default) behaviour is unchanged — fetch fresh each + call. No module-level state is involved. + """ + name = mapping.get("property_name") or mapping.get("property") or "?" + sql = mapping.get("sql_query", "") + src_col, tgt_col = _resolve_edge_columns(mapping) + logger.info( + "evaluate_relationship_mapping: property=%s, src_col=%s, tgt_col=%s", + name, + src_col, + tgt_col, + ) + + try: + edges_result = _exec(execute_sql_fn, sql) + edge_rows = edges_result.get("rows", []) or [] + total_edges = len(edge_rows) + + source_universe = _distinct_id_set( + source_entity_mapping, execute_sql_fn, id_universe_cache + ) + target_universe = _distinct_id_set( + target_entity_mapping, execute_sql_fn, id_universe_cache + ) + except _SqlExecError as exc: + logger.warning( + "evaluate_relationship_mapping: property=%s SQL failed to execute: %s", + name, + exc, + ) + return _sql_error_report(item=name, sql_error=str(exc)) + + src_values = [r.get(src_col) for r in edge_rows] + tgt_values = [r.get(tgt_col) for r in edge_rows] + + if total_edges > 0: + dangling_src = sum( + 1 for v in src_values if v is None or v not in source_universe + ) + dangling_tgt = sum( + 1 for v in tgt_values if v is None or v not in target_universe + ) + dangling_src_pct = dangling_src / total_edges + dangling_tgt_pct = dangling_tgt / total_edges + overlap_pct = 1.0 - dangling_tgt_pct + else: + dangling_src_pct = 0.0 + dangling_tgt_pct = 0.0 + overlap_pct = 0.0 + + metrics: Dict[str, Any] = { + "total_edges": total_edges, + "dangling_source_pct": dangling_src_pct, + "dangling_target_pct": dangling_tgt_pct, + "cross_source_overlap_pct": overlap_pct, + "source_universe_size": len(source_universe), + "target_universe_size": len(target_universe), + } + + failures: List[EvalFailure] = [] + bubble = False + + if total_edges == 
0: + failures.append( + _fail( + check="total_edges", + expected="> 0", + observed="0", + hint=( + f"Relationship '{name}' produced 0 edges. Confirm the join " + "predicate is on the right columns and rows are not being " + "filtered away." + ), + ) + ) + bubble = True + + if total_edges > 0 and dangling_src_pct >= _DANGLING_FK_FAIL_THRESHOLD: + failures.append( + _fail( + check="dangling_source_pct", + expected=f"< {_DANGLING_FK_FAIL_THRESHOLD}", + observed=f"{dangling_src_pct:.3f}", + hint=( + f"{dangling_src_pct:.1%} of source_id values in relationship " + f"'{name}' are absent from the mapped source entity. The " + "source entity's id_column is usually an ALIAS for a derived " + "expression (e.g. CONCAT(regexp_extract(,'...'),'-x')). " + "Reproduce that exact id expression from the source entity's " + "SQL for source_id — do not select a raw/trust-local column." + ), + ) + ) + if dangling_src_pct > _DANGLING_FK_BUBBLE_THRESHOLD: + bubble = True + + # When an explicit cross-source overlap band is provided the relationship + # is *expected* to be partial (e.g. trust_a-only IDs vs the cross-trust + # canonical universe). In that case we trust the band check and skip + # the standard ``dangling_target_pct`` strictness — the partiality is + # the point. The catastrophic-dangling bubble below still fires, but + # only when the band itself ALSO fails (i.e. the realised overlap is + # materially worse than the Planner predicted). + if ( + total_edges > 0 + and dangling_tgt_pct >= _DANGLING_FK_FAIL_THRESHOLD + and expected_cross_source_overlap_band is None + ): + failures.append( + _fail( + check="dangling_target_pct", + expected=f"< {_DANGLING_FK_FAIL_THRESHOLD}", + observed=f"{dangling_tgt_pct:.3f}", + hint=( + f"{dangling_tgt_pct:.1%} of target_id values in relationship " + f"'{name}' are absent from the mapped target entity. 
The " + "target entity's id_column is usually an ALIAS for a derived " + "expression; reproduce that exact id expression from the " + "target entity's SQL for target_id — not a raw join column." + ), + ) + ) + + band_failed = False + if expected_cross_source_overlap_band is not None: + lo, hi = expected_cross_source_overlap_band + if not (lo <= overlap_pct <= hi): + band_failed = True + failures.append( + _fail( + check="cross_source_overlap_pct", + expected=f"in [{lo:.3f}, {hi:.3f}]", + observed=f"{overlap_pct:.3f}", + hint=( + f"Cross-source overlap for '{name}' is {overlap_pct:.1%}, " + f"outside the expected band [{lo:.1%}, {hi:.1%}]. " + "Check the join key and the source/target trust assignments." + ), + ) + ) + + # Bubble-to-planner on catastrophic target-dangling, with a band-aware gate. + # + # * Band absent + dangling > 0.5: the strict dangling_target_pct failure + # above already fired; we just flip the bubble flag (no new row needed). + # * Band present + band PASSED: the Planner predicted this overlap and + # was right — do NOT bubble, even if dangling > 0.5 (the partiality + # was expected). + # * Band present + band FAILED + dangling > 0.5: the realised overlap + # is materially worse than predicted. Bubble, and emit a dedicated + # ``dangling_target_pct_catastrophic`` failure so the FAIL report has + # a concrete structural row alongside the band-check failure. + if total_edges > 0 and dangling_tgt_pct > _DANGLING_FK_BUBBLE_THRESHOLD: + if expected_cross_source_overlap_band is None: + bubble = True + elif band_failed: + bubble = True + failures.append( + _fail( + check="dangling_target_pct_catastrophic", + expected=f"<= {_DANGLING_FK_BUBBLE_THRESHOLD}", + observed=f"{dangling_tgt_pct:.3f}", + hint=( + f"{dangling_tgt_pct:.1%} of target_id values in " + f"relationship '{name}' are absent from the mapped " + "target entity AND the realised overlap is outside " + "the predicted band. Re-plan the join key and the " + "source/target trust assignments." 
+ ), + ) + ) + + logger.info( + "evaluate_relationship_mapping: %s -> %s (%d failure(s), bubble=%s)", + name, + "PASS" if not failures else "FAIL", + len(failures), + bubble, + ) + return build_report( + stage="deterministic", + metrics=metrics, + failures=failures, + bubble_to_planner=bubble, + ) diff --git a/src/agents/agent_mapping_pge/evaluator/report.py b/src/agents/agent_mapping_pge/evaluator/report.py new file mode 100644 index 00000000..532f0a7d --- /dev/null +++ b/src/agents/agent_mapping_pge/evaluator/report.py @@ -0,0 +1,37 @@ +"""Small helpers for assembling :class:`EvalReport` objects. + +The dataclasses themselves live in +:mod:`agents.agent_mapping_pge.contracts`; this module just centralises the +"compose a report from a list of failures" boilerplate so the deterministic +and (future) semantic evaluators stay short. +""" + +from typing import Any, Dict, List + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport + +logger = get_logger(__name__) + + +def build_report( + *, + stage: str, + metrics: Dict[str, Any], + failures: List[EvalFailure], + bubble_to_planner: bool, +) -> EvalReport: + """Assemble an :class:`EvalReport`; status is derived from ``failures``.""" + status = "PASS" if not failures else "FAIL" + if bubble_to_planner and status == "PASS": + logger.warning( + "build_report: bubble_to_planner=True but no failures → demoted " + "to False; check caller logic" + ) + return EvalReport( + status=status, + stage=stage, + metrics=dict(metrics), + failures=list(failures), + bubble_to_planner=bool(bubble_to_planner) and status == "FAIL", + ) diff --git a/src/agents/agent_mapping_pge/generators/__init__.py b/src/agents/agent_mapping_pge/generators/__init__.py new file mode 100644 index 00000000..575f858c --- /dev/null +++ b/src/agents/agent_mapping_pge/generators/__init__.py @@ -0,0 +1,31 @@ +"""Generator agents for the mapping-PGE pipeline. 
+ +Each Generator is a narrow tool-calling agent that maps ONE ontology item +(class or relationship) at a time. The orchestrator (Sprint 7) calls them +per-item with a filtered slice of the Planner's :class:`SourceModel` — the +Generators never see the full ontology or full metadata, keeping each +decision cheap and local. + +* Sprint 4 — :mod:`agents.agent_mapping_pge.generators.entity`. +* Sprint 5 — :mod:`agents.agent_mapping_pge.generators.relationship`. +""" + +from agents.agent_mapping_pge.generators.entity import ( + EntityGenResult, + EntityGenStep, + run_entity_generator, +) +from agents.agent_mapping_pge.generators.relationship import ( + RelationshipGenResult, + RelationshipGenStep, + run_relationship_generator, +) + +__all__ = [ + "EntityGenResult", + "EntityGenStep", + "run_entity_generator", + "RelationshipGenResult", + "RelationshipGenStep", + "run_relationship_generator", +] diff --git a/src/agents/agent_mapping_pge/generators/entity.py b/src/agents/agent_mapping_pge/generators/entity.py new file mode 100644 index 00000000..973e9599 --- /dev/null +++ b/src/agents/agent_mapping_pge/generators/entity.py @@ -0,0 +1,812 @@ +""" +OntoBricks Mapping-PGE EntityGenerator Agent. + +Sprint 4 of the Planner-Generator-Evaluator (PGE) redesign. + +The EntityGenerator is a narrow, focused LLM agent that maps **one** ontology +class at a time. The orchestrator (Sprint 7) calls it per item with a +filtered slice of the Planner's :class:`SourceModel`: + +* the single ontology class to map, with its full attribute list, and +* a small SourceModel slice — only the candidate tables / canonical IDs / + joins that are relevant to *this* class. + +The Generator does NOT see the full ontology or full metadata. That is the +core design contract: keep its context bounded and each decision cheap. 
+ +The loop shape mirrors :mod:`agents.agent_mapping_pge.planner` — same +``call_serving_endpoint`` + ``dispatch_tool`` ReAct cycle, a 1-second +inter-iteration delay, same MLflow trace decorator — with these differences: + +* Smaller default budget (12 vs 25): mapping one class is bounded work. +* Different tool set: only ``execute_sql``, ``sample_table``, and the + terminal ``submit_entity_mapping``. The slice already carries every piece + of context the Generator needs. +* No single-shot fallback: if the endpoint refuses tools, the Generator + reports failure — it produces structured output through + ``submit_entity_mapping`` only. +* The "NO SILENT DROPS" invariant: every ontology attribute must be either + in ``attribute_mappings`` or in ``unmapped_attributes`` with a one-sentence + reason. The system prompt enforces this; the tool persists it. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +import requests + +from back.core.logging import get_logger +from agents.engine_base import ( + call_serving_endpoint, + dispatch_tool, + accumulate_usage, +) +from agents.tools.context import ToolContext +from agents.tools.mapping import ( + MAPPING_TOOL_DEFINITIONS_BY_NAME, + MAPPING_TOOL_HANDLERS, +) +from agents.tools.planner import ( + SAMPLE_TABLE_DEF, + tool_sample_table, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 12 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 +# See planner._MAX_TOKENS comment — same rationale for the Generator's +# submit_entity_mapping JSON (SQL + attribute_mappings can be large).
+_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_entity_generator" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The EntityGenerator only needs: +# * execute_sql – validate the composed SELECT before submitting. +# * sample_table – disambiguate when two candidate tables are equally +# plausible (e.g. same confidence in the slice). +# * submit_entity_mapping – TERMINAL. +# +# We deliberately exclude: +# * get_ontology / get_metadata / get_documents_context — the Planner's +# view; the slice already has what's needed. +# * column_value_overlap / distinct_count — those validate join keys and +# canonical IDs, which the Planner already locked in. +# * submit_relationship_mapping / submit_source_model — wrong stage. + +# Filter MAPPING_TOOL_DEFINITIONS down to just submit_entity_mapping. We +# look up by name from the by-name index in ``mapping.py`` rather than +# scanning the list inline. Sprint 5 will reuse the same pattern for +# ``submit_relationship_mapping``. +_SUBMIT_ENTITY_DEF: dict = MAPPING_TOOL_DEFINITIONS_BY_NAME["submit_entity_mapping"] + +TOOL_DEFINITIONS: List[dict] = ( + SQL_TOOL_DEFINITIONS + + [SAMPLE_TABLE_DEF] + + [_SUBMIT_ENTITY_DEF] +) + +TOOL_HANDLERS: Dict[str, Callable] = { + **SQL_TOOL_HANDLERS, + "sample_table": tool_sample_table, + "submit_entity_mapping": MAPPING_TOOL_HANDLERS["submit_entity_mapping"], +} + + +# ===================================================== +# Data classes +# ===================================================== + + +@dataclass +class EntityGenStep: + """One observable step of the EntityGenerator's execution. + + Mirrors :class:`agents.agent_mapping_pge.planner.PlannerStep` but is + scoped to the Generator so the orchestrator (Sprint 7) can render a + per-class timeline in the UI. 
+ """ + + step_type: str # "tool_call" | "tool_result" | "output" + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class EntityGenResult: + """Outcome of a single EntityGenerator invocation. + + ``mapping`` holds the submitted entity-mapping dict (the same shape the + handler appends to ``ctx.entity_mappings``) when ``success`` is True. + """ + + success: bool + mapping: Optional[dict] = None + steps: List[EntityGenStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + + +# ===================================================== +# System prompt +# ===================================================== +# +# The ENTITY SQL RULES section is lifted verbatim from the legacy in-house +# mapping agent (the section starting "SQL RULES FOR ENTITIES") because +# those rules are correct and load-bearing — every mapping query must +# follow them or downstream SPARQL translation breaks. +# +# The PGE-specific additions are the slice-consumption rules: pick the +# best candidate table from the slice, use the canonical ID exactly as +# the Planner specified it, and account for every ontology attribute. + +SYSTEM_PROMPT = """\ +You are a senior data engineer. Your job is to map ONE ontology class to a \ +single SQL SELECT query against a Databricks source table, validated against \ +real data via execute_sql, and submitted via submit_entity_mapping. + +YOU WILL BE GIVEN +• ontology_class: the class to map (uri, label, comment, attributes list). +• source_model_slice: a small JSON object the Planner already curated for \ +this class: + - candidate_tables[]: {table, confidence, reason} — the tables that could \ +realise this class. + - canonical_id.canonical_column_per_table[
]: the expression that \ +MUST be aliased AS ID for each table. THIS VALUE MAY BE A BARE COLUMN \ +NAME ("MOTHER_NHS_NO") OR A FULL SQL EXPRESSION \ +("regexp_extract(pregnancy_id, '([a-f0-9-]+-preg-[0-9]+)')"). Drop it \ +verbatim into the SELECT and alias it AS ID — do NOT rewrite it, do NOT \ +pick a different column, do NOT strip the function call. The Planner emits \ +SQL expressions when raw column values across trusts are in different \ +formats and need to be normalized to a common canonical key. + - canonical_id.format_note: a one-sentence note describing the canonical \ +key (may be empty). Read it to understand what each row's ID represents. + - relevant_joins[]: optional — any joins the Planner thinks may apply. + +SINGLE-SOURCE vs CROSS-SOURCE (CRITICAL — read carefully) +The number of entries in canonical_id.canonical_column_per_table is the \ +authoritative signal for how to shape your SELECT: + + • If canonical_column_per_table has EXACTLY ONE table → single-source \ +class. Write a flat SELECT from that one table. Pick it from the matching \ +candidate_tables entry. + + • If canonical_column_per_table has TWO OR MORE tables → CROSS-SOURCE \ +class (e.g. the same patient or pregnancy realised across multiple trusts). \ +You MUST emit a UNION ALL across ALL listed tables, NOT pick one. Each \ +branch uses that table's canonical-ID column AS ID. Picking just one would \ +produce a Mother (or Pregnancy, etc.) entity that's missing 60–70% of its \ +real instances, and every relationship pointing at it would then dangle. \ +This is the #1 failure mode the orchestrator catches — do not produce it. + + UNION shape (use exactly this pattern — substitute the canonical-ID \ +EXPRESSION exactly as the Planner specified it for that table; do NOT \ +rewrite it): + SELECT AS ID, AS Label, \ + FROM WHERE IS NOT NULL + UNION ALL + SELECT AS ID, AS Label, \ + FROM WHERE IS NOT NULL + UNION ALL + ... 
+ + All branches must return the SAME columns in the SAME order, AND each \ +column must have the SAME TYPE in every branch. If a branch lacks a column \ +another branch has, project a NULL with a matching alias **cast to the same \ +type the real branch uses** (e.g. if branch A has ``BABY_NHS_NO`` typed \ +BIGINT, branch B must use ``CAST(NULL AS BIGINT) AS BABY_NHS_NO`` — not \ +``AS STRING``). When two branches hold the column with DIFFERENT types, cast \ +BOTH to a common type (``CAST(... AS STRING)`` is the safe default). A \ +``CAST_INVALID_INPUT`` / type-mismatch error from execute_sql always means a \ +column's types differ across branches — fix the casts, do not change the ID. + +TOOLS +You have three tools: + • execute_sql – Validate the composed SELECT before submitting. \ +The tool runs your query with a small LIMIT and returns columns + sample \ +rows; the persisted mapping has no LIMIT. + • sample_table – Up to N random rows from a table. Use only when \ +two candidate tables are equally plausible and you need to peek at real \ +values to disambiguate. + • submit_entity_mapping – TERMINAL. Call exactly once, after execute_sql \ +succeeds, with the full mapping payload. + +SQL RULES FOR ENTITIES (CRITICAL) +• Always use the full table name from the slice (catalog.schema.table). +• The FIRST column MUST be aliased AS ID — it MUST be the canonical-ID \ +column the slice specifies for the chosen table. +• The SECOND column MUST be aliased AS Label — pick the most human-readable \ +available column (typically ``name``, ``label``, ``display_name``, or \ +similar). If no human-readable column exists, fall back to the canonical \ +ID column itself aliased AS Label. +• Add one column per ontology data-property attribute you can satisfy from \ +the chosen table. Use the column's original name (no alias). 
+• If the same column serves as both an alias and an attribute, include it \ +twice: once with the alias (AS ID or AS Label) and once with its original \ +name so it appears in attribute_mappings. +• Add WHERE IS NOT NULL to filter null keys. When the ID is a \ +derived expression, also exclude empty extractions (e.g. \ +``WHERE regexp_extract(...) <> ''``). +• DEDUP COLLAPSED KEYS: when the canonical-ID is a derived EXPRESSION that \ +can repeat across rows (e.g. a ``-del`` key where several episode rows \ +share one pregnancy core), the same ID will appear on multiple rows and the \ +evaluator FAILs on "duplicate ID values". Make each node id unique: wrap the \ +UNION in ``SELECT ... FROM () GROUP BY ID`` (taking MAX() of each \ +attribute) or use ``SELECT DISTINCT`` when there are no attributes. The id \ +column must have exactly one row per distinct value. +• Do NOT add LIMIT — the persisted mapping query must return ALL rows. \ +execute_sql adds a small LIMIT internally for validation only. +• Do NOT use ORDER BY, CTEs, or subqueries unless absolutely necessary. +• Write simple, flat SELECT statements. + +ATTRIBUTE COVERAGE — NO SILENT DROPS (CRITICAL) +For EACH ontology attribute on the class, you must do ONE of: + (a) include a SQL column for it in the SELECT, AND add an entry to \ +attribute_mappings mapping the ontology attribute name to the SQL column \ +name (case-sensitive); OR + (b) add it to unmapped_attributes with a one-sentence reason, using the \ +shape {"name": "", "reason": ""}. + +You may NOT silently drop an attribute. The orchestrator will reject any \ +mapping where some ontology attributes appear in neither list. If a column \ +genuinely does not exist on the chosen table, that's an honest unmapped — \ +say so in the reason. + +WORKFLOW +1. Read the ontology class and the source_model_slice carefully. +2. COUNT the entries in canonical_id.canonical_column_per_table: + - one → single-source: pick that table, compose a flat SELECT. 
+ - two or more → cross-source: compose a UNION ALL across ALL of them \ +(see the SINGLE-SOURCE vs CROSS-SOURCE block above). Do NOT pick one. +3. Compose the SELECT (or UNION ALL) following the SQL RULES above. For \ +each branch, the value of canonical_column_per_table[] is what \ +gets aliased AS ID — drop it in verbatim. It may already be a SQL \ +expression (e.g. ``regexp_extract(...)``); do not rewrite it. +4. Call execute_sql to validate the SELECT. If it fails, READ the error and \ +fix the SQL (typically a typo'd column name, mismatched column lists in a \ +UNION, or wrong full_name). Retry as needed. Never submit an un-validated \ +query. +5. Once execute_sql succeeds, call submit_entity_mapping EXACTLY ONCE with: + class_uri, class_name, sql_query (no LIMIT), id_column, label_column, \ +attribute_mappings, unmapped_attributes. +6. That's the terminal step. Do not emit any free text after submitting. + +GENERAL RULES +• Only ever pass row-returning queries (SELECT / WITH …) to execute_sql. +• Do not call get_metadata, get_ontology, or any other tool — they are not \ +available to you. The slice carries everything you need. +• If a retry_hint is present at the top of the user message, treat it as \ +authoritative — your previous attempt failed for the reason stated and you \ +should NOT repeat the same mistake. +""" + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _build_user_prompt( + ontology_class: dict, + source_model_slice: dict, + retry_hint: Optional[str] = None, +) -> str: + """Render the per-class user prompt. + + The orchestrator hands us `ontology_class` and a focused + `source_model_slice`. We emit a structured prompt that: + * surfaces the retry hint up top if one was provided, + * lists the class metadata and attribute list explicitly so the LLM + cannot forget any attribute, and + * embeds the slice as JSON so the LLM can refer to it precisely. 
+ """ + parts: List[str] = [] + + if retry_hint: + parts.append(f"RETRY HINT (authoritative): {retry_hint}") + parts.append("") + + class_uri = ontology_class.get("uri", "") + class_label = ontology_class.get("label") or ontology_class.get("name", "") + class_comment = ontology_class.get("comment", "") or "" + attributes = ontology_class.get("attributes", []) or [] + + attr_summary_lines: List[str] = [] + for attr in attributes: + if isinstance(attr, dict): + attr_name = attr.get("name") or attr.get("label") or attr.get("uri", "?") + attr_type = attr.get("type") or attr.get("range") or "" + attr_summary_lines.append( + f" - {attr_name}" + (f" ({attr_type})" if attr_type else "") + ) + else: + attr_summary_lines.append(f" - {attr}") + + parts.append("ONTOLOGY CLASS") + parts.append(f" uri: {class_uri}") + parts.append(f" label: {class_label}") + if class_comment: + parts.append(f" comment: {class_comment}") + if attr_summary_lines: + parts.append(" attributes ({} total):".format(len(attributes))) + parts.extend(attr_summary_lines) + else: + parts.append(" attributes: (none — only ID and Label required)") + + parts.append("") + parts.append("SOURCE MODEL SLICE") + parts.append(json.dumps(source_model_slice, indent=2, default=str)) + + parts.append("") + parts.append( + "Pick the best candidate table from the slice, compose a flat SELECT " + "following the SQL RULES, validate with execute_sql, then call " + "submit_entity_mapping exactly once. Every ontology attribute must " + "appear in either attribute_mappings or unmapped_attributes — no " + "silent drops." 
+ ) + + prompt = "\n".join(parts) + logger.debug( + "_build_user_prompt for class=%s (%d chars):\n%s", + class_uri, + len(prompt), + prompt, + ) + return prompt + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_entity_generator") +def run_entity_generator( + host: str, + token: str, + endpoint_name: str, + client: Any, + *, + ontology_class: dict, + source_model_slice: dict, + retry_hint: Optional[str] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: int = MAX_ITERATIONS, +) -> EntityGenResult: + """Run the EntityGenerator agent for a single ontology class. + + The agent autonomously composes a SQL SELECT for ``ontology_class`` + against the candidate table(s) in ``source_model_slice``, validates the + SQL with ``execute_sql``, and submits the validated mapping via the + terminal ``submit_entity_mapping`` tool. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client (must expose ``execute_query(sql)``). + ontology_class: Full dict for the SINGLE class to map (uri, label, + comment, attributes list). + source_model_slice: Filtered SourceModel slice with candidate_tables, + canonical_id, and optional relevant_joins. + retry_hint: Optional one-sentence hint from the orchestrator's + previous-attempt evaluation. When present, surfaced at the top of + the user prompt. + on_step: Optional progress callback ``(msg, pct)`` for UI updates. + max_iterations: Upper bound on tool-call iterations (default 12 — + smaller than the Planner because the scope is one class). + + Returns: + An :class:`EntityGenResult`. ``success`` is True iff a mapping was + successfully submitted; in that case ``mapping`` holds the submitted + dict. On failure, ``error`` explains why and ``mapping`` is None. 
+ """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + class_uri = (ontology_class or {}).get("uri", "") + class_label = ( + (ontology_class or {}).get("label") + or (ontology_class or {}).get("name", "") + ) + n_attrs = len(((ontology_class or {}).get("attributes") or [])) + n_candidates = len(((source_model_slice or {}).get("candidate_tables") or [])) + + logger.info( + "===== ENTITY GENERATOR START ===== endpoint=%s, class=%s (%s), " + "attributes=%d, candidate_tables=%d, retry_hint=%s, max_iter=%d", + endpoint_name, + class_label, + class_uri, + n_attrs, + n_candidates, + "yes" if retry_hint else "no", + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + # The slice subsumes metadata/ontology for this agent; the unified + # ToolContext still needs these fields, so we plant the slice into + # ``metadata`` for completeness even though no handler reads it. + metadata={}, + ontology={}, + documents=[], + ) + + result = EntityGenResult(success=False) + + user_prompt = _build_user_prompt( + ontology_class or {}, source_model_slice or {}, retry_hint=retry_hint + ) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "EntityGenerator conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + # Linear ramp 5 → 95 across the iteration budget. submit hits 100. 
+ ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else 5 + logger.info("ENTITY GEN STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify(f"Generating mapping for {class_label or class_uri}…", pct=1) + + # Snapshot the pre-existing mapping count so we can detect "this run + # added a mapping" without relying on absolute counters. (The orchestrator + # in Sprint 7 may reuse a ToolContext across calls; today's `ctx` is + # fresh, but the assertion is cheap and future-proof.) + pre_run_mapping_count = len(ctx.entity_mappings) + + # ------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- EntityGenerator iteration %d/%d — %d messages, mapping=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if len(ctx.entity_mappings) > pre_run_mapping_count else "unset", + ) + notify( + f"Mapping iteration {current_iteration}/{iteration_limit}…", + pct=pct, + ) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" 
+ logger.warning( + "EntityGenerator iteration %d: HTTPError status=%s", + current_iteration, + status, + ) + logger.debug( + "EntityGenerator iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "EntityGenerator: endpoint refused tools — cannot produce a mapping" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "EntityGenerator: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error("EntityGenerator: timeout at iteration %d", current_iteration) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "EntityGenerator: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "EntityGenerator iteration %d: LLM responded in %dms", + current_iteration, + elapsed_ms, + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = llm_response.get("choices", [{}])[0] + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message", {}) + tool_calls = message.get("tool_calls", []) + has_content = bool(message.get("content")) + logger.info( + "EntityGenerator iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + finish_reason, + len(tool_calls), + has_content, + ) + + if 
not tool_calls: + # The Generator must terminate via submit_entity_mapping, never + # via free text. + content = (message.get("content") or "")[:500] + logger.warning( + "EntityGenerator iteration %d: produced text without submitting mapping — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + EntityGenStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "entity generator produced text without submitting mapping" + result.iterations = current_iteration + result.usage = total_usage + notify( + "Entity generator produced text without submitting mapping.", + pct=pct, + ) + return result + + logger.info( + "EntityGenerator iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) + except json.JSONDecodeError: + arguments = {} + + logger.info( + "EntityGenerator iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + # Human-readable progress messages per tool. 
+ if tool_name == "submit_entity_mapping": + notify(f"Submitting mapping for {class_label or class_uri}…", pct=pct) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = arguments.get("sql", "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + EntityGenStep( + step_type="tool_call", + content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "EntityGenerator iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + EntityGenStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_entity_mapping returned + # success=True AND a mapping for THIS class_uri is present in + # ctx.entity_mappings. A submit with a mismatched class_uri (the + # LLM mapped a different class than requested) is NOT terminal — + # we coach the LLM via a corrective tool message and let the loop + # continue so it can resubmit with the right URI. 
+ if tool_name == "submit_entity_mapping": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if parsed.get("success") is True: + matched = any( + m.get("ontology_class") == class_uri + for m in ctx.entity_mappings + ) + if matched: + terminal_success = True + logger.info( + "EntityGenerator iteration %d: submit_entity_mapping succeeded — terminating", + current_iteration, + ) + else: + submitted_uri = arguments.get("class_uri", "") + mismatch_msg = ( + f"submitted class_uri '{submitted_uri}' does not " + f"match requested class_uri '{class_uri}'; " + f"resubmit with class_uri='{class_uri}'" + ) + logger.warning( + "EntityGenerator iteration %d: submit_entity_mapping " + "class_uri mismatch — submitted=%s, requested=%s", + current_iteration, + submitted_uri, + class_uri, + ) + corrective_payload = json.dumps( + {"success": False, "error": mismatch_msg} + ) + # Replace the recorded tool_result step's content so + # the UI / trace reflects the corrective signal + # rather than the original (misleading) success + # response. + result.steps[-1] = EntityGenStep( + step_type="tool_result", + content=corrective_payload, + tool_name=tool_name, + duration_ms=result.steps[-1].duration_ms, + ) + # Replace the tool message just appended to + # ``messages`` so the LLM sees the corrective + # payload on the next turn (one tool message per + # tool_call_id — keep the protocol clean). + messages[-1] = { + "role": "tool", + "tool_call_id": tool_id, + "content": corrective_payload, + } + + if terminal_success: + # Pull the mapping for this class by strict URI match. The + # terminal-success guard above already verified an entry with + # this URI exists; if we somehow can't find one here that's an + # internal invariant violation, not a recoverable failure. 
+ submitted = next( + ( + m + for m in reversed(ctx.entity_mappings) + if m.get("ontology_class") == class_uri + ), + None, + ) + if submitted is None: + result.error = ( + "internal: submit succeeded but mapping not found for class_uri" + ) + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "===== ENTITY GENERATOR FAILED ===== %s (class=%s)", + result.error, + class_uri, + ) + return result + result.success = True + result.mapping = submitted + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== ENTITY GENERATOR COMPLETE ===== class=%s, iterations=%d, " + "prompt_tokens=%d, completion_tokens=%d", + class_uri, + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify(f"Mapping for {class_label or class_uri} complete!", pct=100) + return result + + # Budget exhausted without a successful submit. + result.iterations = iteration_limit + result.usage = total_usage + result.error = "entity generator exhausted iteration budget" + logger.error("===== ENTITY GENERATOR FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/agent_mapping_pge/generators/relationship.py b/src/agents/agent_mapping_pge/generators/relationship.py new file mode 100644 index 00000000..af5347f0 --- /dev/null +++ b/src/agents/agent_mapping_pge/generators/relationship.py @@ -0,0 +1,875 @@ +""" +OntoBricks Mapping-PGE RelationshipGenerator Agent. + +Sprint 5 of the Planner-Generator-Evaluator (PGE) redesign. + +The RelationshipGenerator is the sibling of :mod:`.entity` — same ReAct +loop shape and tooling discipline, narrower scope. 
It maps **one** ontology +property (relationship) at a time, given: + +* the property to map (uri, label, comment, domain, range), +* the source and target **entity mappings already produced by the + EntityGenerator** — crucially, the ``id_column`` each side mapped on, and +* a small SourceModel slice that surfaces the relevant join-key subgraph. + +The system prompt FORBIDS picking endpoint columns that do not match the +already-mapped entity IDs: the source/target endpoint columns are GIVEN. +This keeps relationships consistent with the entities they connect — if a +relationship's ``source_id`` doesn't match the source entity's ``id_column``, +the resulting SPARQL graph cannot join. + +The loop semantics mirror :mod:`.entity`: + +* Same default budget (12). +* Same inter-iteration delay (``_ITERATION_DELAY_SEC``, currently 1 second). +* Same MLflow trace decorator. +* No single-shot fallback (terminate via tool call only). +* Strict ``property_uri`` match on terminal detection — a submit with the + wrong URI is coached via a corrective tool message, not accepted. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +import requests + +from back.core.logging import get_logger +from agents.engine_base import ( + call_serving_endpoint, + dispatch_tool, + accumulate_usage, +) +from agents.tools.context import ToolContext +from agents.tools.mapping import ( + MAPPING_TOOL_DEFINITIONS_BY_NAME, + MAPPING_TOOL_HANDLERS, +) +from agents.tools.planner import ( + SAMPLE_TABLE_DEF, + tool_sample_table, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 12 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 +# See planner._MAX_TOKENS comment — large UNION ALL queries for cross-source +# relationships can exceed a small ceiling. 
+_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_relationship_generator" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The RelationshipGenerator only needs: +# * execute_sql – validate the composed two-column SELECT. +# * sample_table – peek at endpoint columns when the join is +# ambiguous (rare; usually unnecessary). +# * submit_relationship_mapping – TERMINAL. +# +# We deliberately exclude: +# * get_ontology / get_metadata / get_documents_context — wrong stage. +# * column_value_overlap / distinct_count — already locked by the Planner. +# * submit_source_model / submit_entity_mapping — wrong stage. + +_SUBMIT_RELATIONSHIP_DEF: dict = MAPPING_TOOL_DEFINITIONS_BY_NAME[ + "submit_relationship_mapping" +] + +TOOL_DEFINITIONS: List[dict] = ( + SQL_TOOL_DEFINITIONS + + [SAMPLE_TABLE_DEF] + + [_SUBMIT_RELATIONSHIP_DEF] +) + +TOOL_HANDLERS: Dict[str, Callable] = { + **SQL_TOOL_HANDLERS, + "sample_table": tool_sample_table, + "submit_relationship_mapping": MAPPING_TOOL_HANDLERS[ + "submit_relationship_mapping" + ], +} + + +# ===================================================== +# Data classes +# ===================================================== + + +@dataclass +class RelationshipGenStep: + """One observable step of the RelationshipGenerator's execution. + + Mirrors :class:`.entity.EntityGenStep` — scoped to the relationship + generator so the orchestrator (Sprint 7) can render a per-property + timeline in the UI. + """ + + step_type: str # "tool_call" | "tool_result" | "output" + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class RelationshipGenResult: + """Outcome of a single RelationshipGenerator invocation. + + ``mapping`` holds the submitted relationship-mapping dict (the same + shape the handler appends to ``ctx.relationships``) when ``success`` is + True. 
+ """ + + success: bool + mapping: Optional[dict] = None + steps: List[RelationshipGenStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + + +# ===================================================== +# System prompt +# ===================================================== +# +# The RELATIONSHIP SQL RULES section is lifted verbatim from the legacy +# in-house mapping agent (the section starting "SQL RULES FOR +# RELATIONSHIPS"). To those rules we add the Sprint 5 constraints: the +# source and target ID columns are GIVEN by the already-produced entity +# mappings; the LLM may not pick different endpoint columns. + +SYSTEM_PROMPT = """\ +You are a senior data engineer. Your job is to map ONE ontology property \ +(relationship) to a single SQL SELECT query against Databricks source \ +table(s), validated against real data via execute_sql, and submitted via \ +submit_relationship_mapping. + +YOU WILL BE GIVEN +• ontology_property: the property to map (uri, label, comment, domain, range). +• source_entity_mapping / target_entity_mapping: the ALREADY-MAPPED endpoint \ +entities — each with its class_uri, id_column, and the exact SQL it ran. \ +READ BOTH SQLs: they are the source of truth for your endpoint values. +• source_model_slice: relevant_joins[] {from_ref, to_ref, confidence, \ +overlap_pct, kind} and candidate_tables[] the Planner curated. Prefer \ +high-overlap, high-confidence joins. + +THE EDGE MUST CONNECT EXISTING NODES +An edge row is (source_id, target_id). Each value MUST already exist as a \ +node id in the corresponding entity, or it "dangles" and the mapping is \ +rejected (the evaluator fails any mapping with >5% dangling on either side, \ +unless the Planner predicted a cross-source band). Three traps cause almost \ +all dangling — avoid all three: + +TRAP 1 — id_column is an ALIAS FOR A DERIVED EXPRESSION, not a real column. 
+Each entity mints its id with ``SELECT AS `` (the \ +id_column is usually just ``ID``). That expression is often a canonical-key \ +normalization, e.g.:: + + CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-baby') + +There is no ``ID`` column to select. You MUST **reproduce the entity's id \ +EXPRESSION verbatim** (copied from its SQL), applied to your table, for the \ +endpoint. A raw column (a ``*_id`` join key, a trust-local id) will NOT match \ +→ 100% dangling. + +TRAP 2 — building from a table only ONE endpoint entity covers. +The two entities may be sourced from different trusts (compare the FROM \ +tables in each entity's SQL). Their id universes overlap only on the trust(s) \ +BOTH cover. Build the edge from a table present in BOTH entities' FROM lists \ +(the shared-coverage table). Building from a table only the target covers \ +makes every source_id absent from the source → 100% source-dangling (and \ +vice-versa). + +TRAP 3 — column-name / alias mismatch on submit. +Your SELECT MUST alias the two columns exactly ``AS source_id`` and \ +``AS target_id``, and you MUST submit ``source_id_column="source_id"`` and \ +``target_id_column="target_id"``. These name the columns IN YOUR EDGE OUTPUT, \ +NOT the entity's id_column. If they disagree with your SELECT aliases the \ +evaluator reads nothing and every edge dangles. + +WORKED EXAMPLE — ``Baby --hasApgarScore--> Apgar Score`` +Baby is sourced from {trust_a.maternity_episode, trust_b.delivery}; Apgar \ +Score from {trust_a.maternity_episode, trust_c.maternity_event}. Shared \ +coverage = trust_a only (Trap 2). Both ids share the canonical pregnancy core \ +with role suffixes (Trap 1). 
So build from trust_a, reproducing both \ +expressions from one row:: + + SELECT CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-baby') AS source_id, + CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-apgar') AS target_id + FROM fiifi_cdm_demo_catalog.trust_a.maternity_episode + WHERE regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1) <> '' + +Building from trust_c (Apgar's natural home) would dangle 100% on the Baby \ +side, because Baby has no trust_c rows. + +TOOLS + • execute_sql – validate / probe your SELECT (runs with a \ +small LIMIT; the persisted mapping has none). + • sample_table – peek at real values when a column is \ +ambiguous. + • submit_relationship_mapping – TERMINAL. Call EXACTLY ONCE, only after a \ +clean dangling probe (see WORKFLOW step 4). + +SQL RULES +• SELECT exactly two columns: `` AS source_id, AS target_id`` (Trap 1 + Trap 3). +• Build FROM a table both entities cover (Trap 2). Same-trust FK joins: one \ +table, no join. Cross-source: a UNION ALL of per-source SELECTs (each source \ +that holds both cores), or a JOIN on the shared canonical key. +• No LIMIT, no ORDER BY. Always full table names (catalog.schema.table). + +WORKFLOW +1. Read BOTH entity SQLs. Extract each entity's id EXPRESSION (the \ +``SELECT AS ``) and its set of FROM tables. +2. Pick a shared-coverage table (Trap 2). Compose the two-column SELECT, \ +setting source_id to the source entity's id EXPRESSION and target_id to the \ +target entity's id EXPRESSION, reproduced verbatim and aliased ``AS \ +source_id`` / ``AS target_id``. +3. Call execute_sql to confirm the query parses and returns two columns of \ +rows. Read any error and fix it; never submit an un-validated query. +4. SELF-VERIFY THE VALUES BEFORE SUBMITTING (MANDATORY GATE). 
Run this probe \ +via execute_sql: + + WITH rel AS (), + src AS (), + tgt AS () + SELECT + (SELECT COUNT(*) FROM rel) AS edges, + (SELECT COUNT(*) FROM rel r WHERE r.source_id NOT IN (SELECT ID FROM src)) AS dangling_src, + (SELECT COUNT(*) FROM rel r WHERE r.target_id NOT IN (SELECT ID FROM tgt)) AS dangling_tgt + + You may submit ONLY when ``dangling_src`` AND ``dangling_tgt`` are both 0 \ +(or a tiny fraction of edges). If either is high you hit Trap 1 or Trap 2 — \ +fix the endpoint expression or switch to the shared-coverage table, then \ +re-run this probe. Do NOT submit on an unrun or failing probe. +5. submit_relationship_mapping EXACTLY ONCE: property_uri, property_name, \ +sql_query (no LIMIT), source_id_column="source_id", target_id_column=\ +"target_id", domain, range_class. +6. Terminal — emit no free text after submitting. + +GENERAL RULES +• Only ever pass row-returning queries (SELECT / WITH …) to execute_sql. +• Do not call get_metadata, get_ontology, column_value_overlap, \ +distinct_count, submit_entity_mapping, or submit_source_model — they are \ +not available to you. The slice plus the entity mappings carry everything \ +you need. +• If a retry_hint is present at the top of the user message, treat it as \ +authoritative — your previous attempt failed for the reason stated; do NOT \ +repeat the same mistake. +""" + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _summarise_entity_mapping(em: dict, side: str) -> List[str]: + """One-block textual summary of a previously-produced entity mapping. + + Surfaces exactly the fields the LLM needs to constrain its endpoint + choice: the class_uri, the id_column it locked in, and the SQL it ran. + Anything else (label_column, attribute_mappings, …) is irrelevant to the + relationship task and is intentionally omitted to keep the prompt tight. 
+ """ + em = em or {} + class_uri = ( + em.get("ontology_class") or em.get("class_uri") or em.get("class") or "" + ) + id_column = em.get("id_column", "") + sql_query = em.get("sql_query", "") + return [ + f"{side.upper()} ENTITY MAPPING", + f" class_uri: {class_uri}", + f" id_column: {id_column}", + f" sql: {sql_query}", + ] + + +def _format_join(j: dict) -> str: + """Readable one-line rendering of a join entry from the slice. + + Defensive about missing fields — partial joins still render usefully so + a malformed slice doesn't blow up the prompt build. + """ + from_ref = j.get("from_ref", "?") + to_ref = j.get("to_ref", "?") + kind = j.get("kind", "?") + conf = j.get("confidence") + overlap = j.get("overlap_pct") + extras: List[str] = [] + if conf is not None: + try: + extras.append(f"confidence={float(conf):.2f}") + except (TypeError, ValueError): + extras.append(f"confidence={conf}") + if overlap is not None: + try: + extras.append(f"overlap_pct={float(overlap):.2f}") + except (TypeError, ValueError): + extras.append(f"overlap_pct={overlap}") + suffix = (" — " + ", ".join(extras)) if extras else "" + return f" - {from_ref} -> {to_ref} [{kind}]{suffix}" + + +def _build_user_prompt( + ontology_property: dict, + source_entity_mapping: dict, + target_entity_mapping: dict, + source_model_slice: dict, + retry_hint: Optional[str] = None, +) -> str: + """Render the per-property user prompt. + + Structure: + 1. retry_hint (if any) at the very top + 2. ontology property metadata + 3. source entity mapping summary (class_uri / id_column / sql) + 4. target entity mapping summary + 5. relevant joins (one line per join, readable) + 6. candidate_tables (raw JSON — small) + 7. a reminder block reiterating the two-column / endpoint-match rules + """ + parts: List[str] = [] + + if retry_hint: + parts.append("RETRY HINT (authoritative — your previous attempt FAILED):") + parts.append(retry_hint) + parts.append( + "DO NOT repeat the same column choice. 
If the hint mentions " + "'dangling' or 'canonical id': sample BOTH the candidate endpoint " + "column AND the entity's id_column, compare actual values, and " + "pick the column whose values overlap. Run the dangling-edge " + "probe (step 4 of WORKFLOW) BEFORE submitting this time.\n" + ) + + prop_uri = ontology_property.get("uri", "") + prop_label = ( + ontology_property.get("label") or ontology_property.get("name", "") + ) + prop_comment = ontology_property.get("comment", "") or "" + prop_domain = ontology_property.get("domain", "") or "" + prop_range = ontology_property.get("range", "") or "" + + parts.append("ONTOLOGY PROPERTY") + parts.append(f" uri: {prop_uri}") + parts.append(f" label: {prop_label}") + if prop_comment: + parts.append(f" comment: {prop_comment}") + parts.append(f" domain: {prop_domain}") + parts.append(f" range: {prop_range}") + + parts.append("") + parts.extend(_summarise_entity_mapping(source_entity_mapping, side="source")) + + parts.append("") + parts.extend(_summarise_entity_mapping(target_entity_mapping, side="target")) + + slice_obj = source_model_slice or {} + joins = slice_obj.get("relevant_joins") or [] + candidates = slice_obj.get("candidate_tables") or [] + + parts.append("") + parts.append("RELEVANT JOINS") + if joins: + for j in joins: + parts.append(_format_join(j)) + else: + parts.append(" (none surfaced by the Planner — fall back to a single-table SELECT if possible)") + + if candidates: + parts.append("") + parts.append("CANDIDATE TABLES") + parts.append(json.dumps(candidates, indent=2, default=str)) + + src_id = (source_entity_mapping or {}).get("id_column", "") + tgt_id = (target_entity_mapping or {}).get("id_column", "") + + parts.append("") + parts.append("REMINDERS (CRITICAL)") + parts.append( + " • The persisted SQL MUST return EXACTLY two columns aliased " + "AS source_id and AS target_id." 
+ ) + parts.append( + f" • source_id values MUST come from the column '{src_id}' (the " + "source entity's id_column) — or be directly transformable into it " + "via a join key in the slice." + ) + parts.append( + f" • target_id values MUST come from the column '{tgt_id}' (the " + "target entity's id_column) — same constraint." + ) + parts.append( + " • Validate with execute_sql, then call submit_relationship_mapping " + "exactly once." + ) + + prompt = "\n".join(parts) + logger.debug( + "_build_user_prompt for property=%s (%d chars):\n%s", + prop_uri, + len(prompt), + prompt, + ) + return prompt + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_relationship_generator") +def run_relationship_generator( + host: str, + token: str, + endpoint_name: str, + client: Any, + *, + ontology_property: dict, + source_entity_mapping: dict, + target_entity_mapping: dict, + source_model_slice: dict, + retry_hint: Optional[str] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: int = MAX_ITERATIONS, +) -> RelationshipGenResult: + """Run the RelationshipGenerator agent for a single ontology property. + + The agent composes a two-column SQL SELECT (``source_id`` / ``target_id``) + that realises the relationship between the source and target entities + using the join-key subgraph in ``source_model_slice``, validates the + SQL via ``execute_sql``, and submits the validated mapping via the + terminal ``submit_relationship_mapping`` tool. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client (must expose ``execute_query(sql)``). + ontology_property: Full dict for the SINGLE property to map (uri, + label, comment, domain, range). 
+ source_entity_mapping: The ALREADY-MAPPED source entity (carries the + ``id_column`` the source endpoint must align with). + target_entity_mapping: The ALREADY-MAPPED target entity (same). + source_model_slice: Filtered SourceModel slice with relevant_joins + and optional candidate_tables. + retry_hint: Optional one-sentence hint from the orchestrator's + previous-attempt evaluation. When present, surfaced at the top + of the user prompt. + on_step: Optional progress callback ``(msg, pct)`` for UI updates. + max_iterations: Upper bound on tool-call iterations (default 12 — + same as the EntityGenerator). + + Returns: + A :class:`RelationshipGenResult`. ``success`` is True iff a mapping + was successfully submitted with the requested ``property_uri``; in + that case ``mapping`` holds the submitted dict. On failure, ``error`` + explains why and ``mapping`` is None. + """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + property_uri = (ontology_property or {}).get("uri", "") + property_label = ( + (ontology_property or {}).get("label") + or (ontology_property or {}).get("name", "") + ) + n_joins = len(((source_model_slice or {}).get("relevant_joins") or [])) + n_candidates = len(((source_model_slice or {}).get("candidate_tables") or [])) + + logger.info( + "===== RELATIONSHIP GENERATOR START ===== endpoint=%s, property=%s (%s), " + "joins=%d, candidate_tables=%d, retry_hint=%s, max_iter=%d", + endpoint_name, + property_label, + property_uri, + n_joins, + n_candidates, + "yes" if retry_hint else "no", + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + # The slice + entity mappings subsume metadata/ontology for this + # agent; the unified ToolContext still wants these fields, so we + # leave them empty. 
+ metadata={}, + ontology={}, + documents=[], + ) + + result = RelationshipGenResult(success=False) + + user_prompt = _build_user_prompt( + ontology_property or {}, + source_entity_mapping or {}, + target_entity_mapping or {}, + source_model_slice or {}, + retry_hint=retry_hint, + ) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "RelationshipGenerator conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else 5 + logger.info("RELATIONSHIP GEN STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify(f"Generating mapping for {property_label or property_uri}…", pct=1) + + # Snapshot the pre-existing relationship count so we can detect "this + # run added a mapping" without relying on absolute counters. Future-proof + # for an orchestrator that reuses a ToolContext across calls. 
+ pre_run_count = len(ctx.relationships) + + # ------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- RelationshipGenerator iteration %d/%d — %d messages, mapping=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if len(ctx.relationships) > pre_run_count else "unset", + ) + notify( + f"Mapping iteration {current_iteration}/{iteration_limit}…", + pct=pct, + ) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" 
+ logger.warning( + "RelationshipGenerator iteration %d: HTTPError status=%s", + current_iteration, + status, + ) + logger.debug( + "RelationshipGenerator iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: endpoint refused tools — cannot produce a mapping" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: timeout at iteration %d", current_iteration + ) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "RelationshipGenerator iteration %d: LLM responded in %dms", + current_iteration, + elapsed_ms, + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = llm_response.get("choices", [{}])[0] + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message", {}) + tool_calls = message.get("tool_calls", []) + has_content = bool(message.get("content")) + logger.info( + "RelationshipGenerator iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + 
finish_reason, + len(tool_calls), + has_content, + ) + + if not tool_calls: + # The Generator must terminate via submit_relationship_mapping, + # never via free text. + content = (message.get("content") or "")[:500] + logger.warning( + "RelationshipGenerator iteration %d: produced text without submitting mapping — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + RelationshipGenStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "relationship generator produced text without submitting mapping" + result.iterations = current_iteration + result.usage = total_usage + notify( + "Relationship generator produced text without submitting mapping.", + pct=pct, + ) + return result + + logger.info( + "RelationshipGenerator iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) + except json.JSONDecodeError: + arguments = {} + + logger.info( + "RelationshipGenerator iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + if tool_name == "submit_relationship_mapping": + notify( + f"Submitting mapping for {property_label or property_uri}…", + pct=pct, + ) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = arguments.get("sql", "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + RelationshipGenStep( + step_type="tool_call", + 
content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "RelationshipGenerator iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + RelationshipGenStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_relationship_mapping returned + # success=True AND a mapping for THIS property_uri is present in + # ctx.relationships. A submit with a mismatched property_uri is + # NOT terminal — we coach the LLM via a corrective tool message + # and let the loop continue. 
+ if tool_name == "submit_relationship_mapping": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if parsed.get("success") is True: + matched = any( + m.get("property") == property_uri + for m in ctx.relationships + ) + if matched: + terminal_success = True + logger.info( + "RelationshipGenerator iteration %d: submit_relationship_mapping succeeded — terminating", + current_iteration, + ) + else: + submitted_uri = arguments.get("property_uri", "") + mismatch_msg = ( + f"submitted property_uri '{submitted_uri}' does " + f"not match requested property_uri " + f"'{property_uri}'; resubmit with " + f"property_uri='{property_uri}'" + ) + logger.warning( + "RelationshipGenerator iteration %d: submit_relationship_mapping " + "property_uri mismatch — submitted=%s, requested=%s", + current_iteration, + submitted_uri, + property_uri, + ) + corrective_payload = json.dumps( + {"success": False, "error": mismatch_msg} + ) + # Replace the recorded tool_result step's content so + # the UI / trace shows the corrective signal. + result.steps[-1] = RelationshipGenStep( + step_type="tool_result", + content=corrective_payload, + tool_name=tool_name, + duration_ms=result.steps[-1].duration_ms, + ) + # Replace the tool message on the conversation so + # the LLM sees the corrective payload next turn. + messages[-1] = { + "role": "tool", + "tool_call_id": tool_id, + "content": corrective_payload, + } + + if terminal_success: + # Pull the mapping for this property by strict URI match. 
+ submitted = next( + ( + m + for m in reversed(ctx.relationships) + if m.get("property") == property_uri + ), + None, + ) + if submitted is None: + result.error = ( + "internal: submit succeeded but mapping not found for property_uri" + ) + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "===== RELATIONSHIP GENERATOR FAILED ===== %s (property=%s)", + result.error, + property_uri, + ) + return result + result.success = True + result.mapping = submitted + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== RELATIONSHIP GENERATOR COMPLETE ===== property=%s, iterations=%d, " + "prompt_tokens=%d, completion_tokens=%d", + property_uri, + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify( + f"Mapping for {property_label or property_uri} complete!", pct=100 + ) + return result + + # Budget exhausted without a successful submit. + result.iterations = iteration_limit + result.usage = total_usage + result.error = "relationship generator exhausted iteration budget" + logger.error("===== RELATIONSHIP GENERATOR FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/agent_mapping_pge/planner.py b/src/agents/agent_mapping_pge/planner.py new file mode 100644 index 00000000..00403630 --- /dev/null +++ b/src/agents/agent_mapping_pge/planner.py @@ -0,0 +1,729 @@ +""" +OntoBricks Mapping-PGE Planner Agent. + +Sprint 3 of the Planner-Generator-Evaluator (PGE) redesign. + +The Planner is a single-invocation agent (no internal retry loop — re- +invocations come from the orchestrator on Evaluator escalation in Sprint 7). 
+It consumes the ontology, table metadata, and any imported domain documents, +probes the source data via the planner tools (sample_table, column_value_overlap, +normalized_value_overlap, distinct_count) plus the shared tools (get_metadata, +get_ontology, get_documents_context, execute_sql), and emits a validated +:class:`SourceModel` via the ``submit_source_model`` terminal tool. + +The loop semantics mirror the prior single-loop mapping agent — same +``call_serving_endpoint`` + ``dispatch_tool`` ReAct cycle, same style of +inter-iteration delay (``_ITERATION_DELAY_SEC``, currently 1 second), same accumulated usage tracking, same MLflow trace +decorator — with two key differences: + +* No fallback to single-shot generation. If the endpoint refuses tools, the + Planner returns failure (the Planner *needs* tools — it produces structured + output through ``submit_source_model``). +* Smaller default iteration budget (50 instead of 60) — the Planner is more + focused than the auto-mapping agent. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import requests + +if TYPE_CHECKING: + from agents.agent_mapping_pge.contracts import SourceModel + +from back.core.logging import get_logger +from agents.engine_base import ( + call_serving_endpoint, + dispatch_tool, + accumulate_usage, +) +from agents.tools.context import ToolContext +from agents.tools.documents import ( + GET_DOCUMENTS_CONTEXT_DEF, + tool_get_documents_context, +) +from agents.tools.metadata import ( + GET_METADATA_DEF, + tool_get_metadata, +) +from agents.tools.ontology import ( + ONTOLOGY_TOOL_DEFINITIONS, + ONTOLOGY_TOOL_HANDLERS, +) +from agents.tools.planner import ( + PLANNER_TOOL_DEFINITIONS, + PLANNER_TOOL_HANDLERS, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 50 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 + +# The submit_source_model JSON for a 
real-world ontology can run several KB +# (17+ classes × multiple candidates + canonical_ids + join_keys + plan). +# A small ceiling silently truncates the call (finish_reason=length) and the +# dataclass validation fails with no clue to the LLM as to why. The 50k value below is meant to remove +# the practical ceiling for any ontology size; you only pay for tokens +# actually generated, so the cost stays bounded by output complexity. +_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_planner" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The Planner uses every read tool the auto-mapping agent has — ontology, +# metadata, documents, execute_sql — *plus* the four planner-specific tools. +# It deliberately does NOT receive ``submit_entity_mapping`` / +# ``submit_relationship_mapping``: those belong to the Generator (Sprints 4 +# and 5). The Planner's only terminal tool is ``submit_source_model``. + +TOOL_DEFINITIONS: List[dict] = ( + [GET_METADATA_DEF, GET_DOCUMENTS_CONTEXT_DEF] + + ONTOLOGY_TOOL_DEFINITIONS + + SQL_TOOL_DEFINITIONS + + PLANNER_TOOL_DEFINITIONS +) + +TOOL_HANDLERS: Dict[str, Callable] = { + "get_metadata": tool_get_metadata, + "get_documents_context": tool_get_documents_context, + **ONTOLOGY_TOOL_HANDLERS, + **SQL_TOOL_HANDLERS, + **PLANNER_TOOL_HANDLERS, +} + + +# ===================================================== +# Data classes +# ===================================================== + + +@dataclass +class PlannerStep: + """One observable step of the Planner's execution. + + Mirrors :class:`agents.engine_base.AgentStep` but is scoped to the Planner + so the orchestrator (Sprint 7) can present a stage-specific timeline in + the UI. + """ + + step_type: str # tool_call | tool_result | output + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class PlannerResult: + """Outcome of a single Planner invocation. 
+ + ``source_model`` is populated only when the LLM successfully called + ``submit_source_model`` with a structurally-valid payload. ``error`` is + the short reason string when ``success`` is ``False``. + """ + + success: bool + source_model: Optional["SourceModel"] = None + steps: List[PlannerStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + + +# ===================================================== +# System prompt +# ===================================================== + +SYSTEM_PROMPT = """\ +You are a senior data architect. Your job is to build a SourceModel that \ +bridges a set of source tables to an OWL ontology, so a downstream Generator \ +agent can mechanically emit entity- and relationship-mapping SQL. + +TOOLS +You have these tools available: + • get_ontology – load classes (with attributes) and object \ +properties to be mapped. + • get_metadata – load imported table schemas (full names, \ +columns, types). + • get_documents_context – load any imported domain documents (glossaries, \ +schema docs). + • sample_table – return up to N random rows so you can see \ +actual values, not just column types. Use when a column's role is unclear \ +from its name/type alone. + • column_value_overlap – measure |distinct(from) ∩ distinct(to)| / \ +|distinct(from)| for two bare COLUMNS. Use to VALIDATE a candidate join key \ +with real data — never propose a join_key on the strength of name similarity \ +alone. + • normalized_value_overlap – the same overlap metric, but each side is a \ +scalar SQL EXPRESSION. This is how you PROVE a canonical-key normalization: \ +when two tables for the same class have 0% raw overlap, propose a \ +normalization expression per table and confirm overlap_pct > 0 here BEFORE \ +you submit. A still-zero result means your expression is wrong — fix it. + • distinct_count – row / distinct / null counts plus is_unique \ +and is_complete flags. 
Use to confirm a candidate canonical-ID column is \ +actually unique and complete. + • execute_sql – escape hatch for any check the four tools above \ +do not cover. Use sparingly — prefer the focused tools. + • submit_source_model – TERMINAL. Call exactly once, when the \ +SourceModel is complete and you are ready to hand off to the Generator. + +WORKFLOW +1. Call get_ontology AND get_metadata first to see what needs mapping and \ +what data is available. +2. Call get_documents_context to pick up any pre-loaded domain documents — \ +they often disambiguate column semantics. +3. For each table, decide which ontology class(es) it could realise — these \ +become table_roles[].ontology_class_candidates with a confidence and a one- \ +sentence reason. +4. For each ontology class, decide which column serves as its canonical \ +identifier in each table — record under canonical_ids[]. When you are \ +uncertain, run distinct_count to confirm uniqueness/completeness. +5. For each pair of tables that should join (intra-trust FK or cross-source \ +value match), run column_value_overlap and only record join_keys[] when the \ +realised overlap_pct supports it. Use kind="same_trust_fk" for FK joins and \ +kind="cross_source_value_match" for value-matched joins across sources. \ +For any class mapped to 2+ tables, follow CANONICAL-KEY NORMALIZATION below \ +and PROVE the chosen keys overlap with normalized_value_overlap. +6. Build mapping_plan.entity_order so that BASE classes come first \ +(i.e. classes that are referenced by other classes through object properties \ +should be mapped before their referencers). Build \ +mapping_plan.relationship_order so that, by the time each relationship is \ +attempted, BOTH its domain and range classes have already appeared in \ +entity_order. List anything you cannot reasonably map under \ +mapping_plan.skip[] with a short reason. +7. Finally, call submit_source_model exactly once with the full JSON. 
The \ +call returns success=true when the model is structurally valid; if it \ +returns success=false, fix the indicated problem and call it again. + +CANONICAL-KEY NORMALIZATION (CRITICAL — this is the #1 cause of relationship dangling) +For any class whose canonical_id lists MORE THAN ONE table, run \ +column_value_overlap on a representative column pair to see whether the raw \ +values already share a format: + + • If overlap_pct > 0 → values are in compatible formats. Record bare \ +column names in canonical_column_per_table (e.g. ``"MOTHER_NHS_NO"``). \ +A UNION across the tables produces a coherent ID universe. + + • If overlap_pct == 0 → DO NOT conclude these are "different" or \ +"trust-scoped" entities. When two tables both map to the SAME ontology \ +class, 0% overlap almost always means the SAME real-world key wrapped in \ +DIFFERENT trust-local encodings (prefixes, suffixes, embedded sub-IDs). \ +Leaving them disjoint makes every relationship pointing at this class 100% \ +dangle — that is a FAILURE, not an acceptable outcome. You MUST normalize: + + STEP 1 — sample_table BOTH columns and read the raw values. Look for a \ +shared embedded substring across the trusts — a stable inner identifier \ +(UUID, NHS number, ``...-preg-`` core) that appears in every trust's \ +value with only the surrounding prefix/suffix differing. + + STEP 2 — write ONE scalar SQL expression PER TABLE that strips the \ +trust-specific wrapping and exposes that shared core in an identical form. \ +Prefer extracting the shared core over stripping a single known prefix \ +(extraction is robust to multiple prefixes). 
When matching a hex/UUID core, \ +ALWAYS anchor the regex with a leading character class so a preceding dash \ +is not captured: + ✗ WRONG: regexp_extract(EPISODE_ID, '([a-f0-9-]+-preg-[0-9]+)', 1) + → returns "--preg-1" (leading dash) — will NOT match + ✓ RIGHT: regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1) + → returns "-preg-1" + + STEP 3 — for a DERIVED / child key (e.g. a Delivery, Baby or Apgar that \ +hangs off a pregnancy), DO NOT concatenate a suffix onto the RAW prefixed \ +local id — that re-introduces the trust prefix and the keys stay disjoint. \ +Extract the shared core FIRST, then append the role suffix, so every trust \ +yields the identical synthetic key: + ✗ WRONG: trust_a "CONCAT(EPISODE_ID, '-del')" (→ STA--preg-1-del) + trust_b "delivery_id" (→ BUH-DEL-BUH--preg-1) + ✓ RIGHT: trust_a "CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-del')" + trust_b "CONCAT(regexp_extract(delivery_id, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-del')" + (both → -preg-1-del) + + STEP 4 — PROVE IT. Call normalized_value_overlap with your two \ +expressions. It MUST return overlap_pct > 0. If it is still 0, your \ +expressions land in different value spaces — go back to STEP 1 and fix them. \ +Do NOT call submit_source_model with an unverified normalization. + + (If, after sampling, a table genuinely cannot expose the shared core at \ +all, omit that table from canonical_column_per_table and note why — but this \ +is rare; exhaust STEP 1–4 first.) + + • Whatever expression you record, the EntityGenerator drops it verbatim \ +into the SELECT aliased AS ID. Bare column names and SQL expressions are \ +both valid here. + + • Always update format_note to one sentence describing what the canonical \ +key looks like (e.g. ``"-preg- core extracted from each \ +trust's local pregnancy id"``). Downstream agents read this. 
+ +SOURCEMODEL JSON SCHEMA (these key names are LOAD-BEARING — do not improvise) +The `model` argument to submit_source_model has exactly this shape: + +{ + "table_roles": [ + { + "table": "", // STRING — required key is "table" + "ontology_class_candidates": [ + {"uri": "", "confidence": 0.0, "reason": ""} + ] + } + ], + "canonical_ids": [ + { + "ontology_class": "", // STRING — required key is "ontology_class" + // VALUES may be either a bare column name OR a SQL expression that + // produces the canonical key for that table. Use a SQL expression + // when raw column values across the listed tables are in different + // formats (see CANONICAL-KEY NORMALIZATION below). + "canonical_column_per_table": {"": ""}, + "format_note": "" + } + ], + "join_keys": [ + { + "from_ref": "
.", // STRING — required key is "from_ref" + "to_ref": "
.", // STRING — required key is "to_ref" + "confidence": 0.0, + "overlap_pct": 0.0, + "kind": "same_trust_fk" // or "cross_source_value_match" + } + ], + "mapping_plan": { + "entity_order": ["", "..."], + "relationship_order": ["", "..."], + "skip": [ + {"item": "", "reason": ""} // required keys: "item", "reason" + ] + } +} + +Key-name traps to avoid: +• Use "table" (not "name", "table_name", "uri") in each table_roles[] entry. +• Use "ontology_class" (not "class", "uri") in each canonical_ids[] entry. +• Use "from_ref" / "to_ref" (not "from" / "to" / "source" / "target") in each join_keys[] entry. +• Use "item" (not "uri", "property") in each mapping_plan.skip[] entry. + +INVARIANTS (the orchestrator will enforce these) +• Every URI in entity_order MUST exist in the ontology AND have at least one \ +candidate in table_roles[].ontology_class_candidates. +• Every URI in relationship_order MUST reference a property whose domain \ +class and range class both appear in entity_order at an EARLIER position. +• All confidence values are floats in [0.0, 1.0]. +• kind on each join_key is EXACTLY one of: "same_trust_fk", \ +"cross_source_value_match". +• Call submit_source_model EXACTLY ONCE, at the end. Do not emit a free-text \ +summary afterwards — submit_source_model is the terminal step. + +GENERAL RULES +• Prefer the focused tools (sample_table, column_value_overlap, \ +normalized_value_overlap, distinct_count) over execute_sql. +• Validate candidate join keys with column_value_overlap before adding them \ +to join_keys[]. +• You may batch multiple independent tool calls in a single response. +• Only ever pass row-returning queries (SELECT / WITH …) to execute_sql. 
+""" + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _build_user_prompt( + entities: List[dict], relationships: List[dict], n_tables: int +) -> str: + parts = [ + ( + f"Build a SourceModel for {n_tables} table(s), {len(entities)} ontology " + f"entity/entities, and {len(relationships)} relationship(s). " + "Start by calling get_ontology, get_metadata, and get_documents_context." + ) + ] + if entities: + names = ", ".join(e.get("name", "?") for e in entities) + parts.append(f"Entities in scope: {names}") + if relationships: + names = ", ".join(r.get("name", "?") for r in relationships) + parts.append(f"Relationships in scope: {names}") + prompt = "\n".join(parts) + logger.debug("_build_user_prompt (%d chars):\n%s", len(prompt), prompt) + return prompt + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_planner") +def run_planner( + host: str, + token: str, + endpoint_name: str, + client: Any, + metadata: dict, + ontology: dict, + *, + documents: Optional[list] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: int = MAX_ITERATIONS, +) -> PlannerResult: + """Run the Planner agent. + + The Planner autonomously produces a :class:`SourceModel` by exploring the + ontology, metadata, documents, and source data via tool calls. It + terminates as soon as it submits a structurally-valid SourceModel via the + terminal ``submit_source_model`` tool. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client (must expose ``execute_query(sql)``). + metadata: Imported domain metadata (``{"tables": [...]}``). + ontology: Imported ontology (``{"entities": [...], "relationships": [...]}``). 
+ documents: Optional pre-loaded domain documents. + on_step: Optional progress callback ``(msg, pct)`` for UI updates. + max_iterations: Upper bound on tool-call iterations (default ``MAX_ITERATIONS``, currently 50). + + Returns: + A :class:`PlannerResult`. ``success`` is True iff a SourceModel was + successfully submitted; in that case ``source_model`` holds the + validated dataclass. On failure, ``error`` explains why and + ``source_model`` is None. + """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + entities = (ontology or {}).get("entities", []) + relationships = (ontology or {}).get("relationships", []) + n_tables = len((metadata or {}).get("tables", [])) + + logger.info( + "===== PLANNER START ===== endpoint=%s, tables=%d, entities=%d, relationships=%d, max_iter=%d", + endpoint_name, + n_tables, + len(entities), + len(relationships), + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + metadata=metadata or {}, + ontology=ontology or {}, + documents=list(documents or []), + ) + + result = PlannerResult(success=False) + + user_prompt = _build_user_prompt(entities, relationships, n_tables) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "Planner conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + # Linear ramp from 5 → 95 across the iteration budget. The terminal + # submit_source_model call is what sets 100. 
+ ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None): + actual_pct = pct if pct is not None else 5 + logger.info("PLANNER STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify("Starting planner…", pct=1) + + # ------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + # Rate-limit mitigation — pause _ITERATION_DELAY_SEC seconds before every LLM call after the first. + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- Planner iteration %d/%d — %d messages, source_model=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if ctx.source_model is not None else "unset", + ) + notify(f"Planning iteration {current_iteration}/{iteration_limit}…", pct=pct) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" + logger.warning( + "Planner iteration %d: HTTPError status=%s", current_iteration, status + ) + logger.debug( + "Planner iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + # Tools are non-negotiable for the Planner — no single-shot fallback. 
+ if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Planner: endpoint refused tools — cannot produce a SourceModel" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Planner: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error("Planner: timeout at iteration %d", current_iteration) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Planner: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "Planner iteration %d: LLM responded in %dms", current_iteration, elapsed_ms + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = llm_response.get("choices", [{}])[0] + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message", {}) + tool_calls = message.get("tool_calls", []) + has_content = bool(message.get("content")) + logger.info( + "Planner iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + finish_reason, + len(tool_calls), + has_content, + ) + # A tool call truncated by the max_tokens ceiling produces malformed + # arguments and the tool can't recover. Flag it loudly so future runs + # don't silently waste iterations resubmitting the same broken JSON. 
+ if finish_reason == "length" and tool_calls: + logger.error( + "Planner iteration %d: finish_reason=length on a tool call — " + "arguments were likely truncated. Consider bumping max_tokens.", + current_iteration, + ) + + if not tool_calls: + # The Planner must end with submit_source_model, not free text. + # If we see text without a terminal call, that's a failure. + content = (message.get("content") or "")[:500] + logger.warning( + "Planner iteration %d: produced text without submitting source model — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + PlannerStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "planner produced text without submitting source model" + result.iterations = current_iteration + result.usage = total_usage + notify("Planner produced text without submitting source model.", pct=pct) + return result + + # Tool-call branch — dispatch each call and accumulate steps. + logger.info( + "Planner iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) + except json.JSONDecodeError: + arguments = {} + + logger.info( + "Planner iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + # Human-readable progress messages per tool — same pattern as + # the legacy mapping agent for UI consistency. 
+ if tool_name == "submit_source_model": + notify("Submitting source model…", pct=pct) + elif tool_name == "get_metadata": + notify("Retrieving table metadata…", pct=pct) + elif tool_name == "get_ontology": + notify("Retrieving ontology…", pct=pct) + elif tool_name == "get_documents_context": + notify("Retrieving documents…", pct=pct) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "column_value_overlap": + notify("Checking column overlap…", pct=pct) + elif tool_name == "normalized_value_overlap": + notify("Verifying canonical-key normalization…", pct=pct) + elif tool_name == "distinct_count": + notify("Checking distinct count…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = arguments.get("sql", "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + PlannerStep( + step_type="tool_call", + content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "Planner iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + PlannerStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_source_model returned success=True + # *and* stamped a SourceModel onto the context. We break *after* + # appending the tool result so the orchestrator sees a complete + # message trail in conversation/replay. 
+ if tool_name == "submit_source_model": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if parsed.get("success") is True and ctx.source_model is not None: + terminal_success = True + logger.info( + "Planner iteration %d: submit_source_model succeeded — terminating", + current_iteration, + ) + + if terminal_success: + result.success = True + result.source_model = ctx.source_model + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== PLANNER COMPLETE ===== iterations=%d, " + "prompt_tokens=%d, completion_tokens=%d", + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify("Planner completed!", pct=100) + return result + + # Exhausted the iteration budget without ever calling submit_source_model + # successfully (or the LLM kept calling other tools forever). + result.iterations = iteration_limit + result.usage = total_usage + result.error = "planner exhausted iteration budget without submitting source model" + logger.error("===== PLANNER FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/tools/context.py b/src/agents/tools/context.py index 3b88df82..6edf3919 100644 --- a/src/agents/tools/context.py +++ b/src/agents/tools/context.py @@ -6,7 +6,10 @@ """ from dataclasses import dataclass, field -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from agents.agent_mapping_pge.contracts import EvalReport, SourceModel @dataclass @@ -53,3 +56,13 @@ class ToolContext: dtwin_registry_params: dict = field(default_factory=dict) dtwin_domain_name: str = "" dtwin_ontology_labels: dict = field(default_factory=dict) # uri/name → display label + + # Mapping PGE planner output (``agent_mapping_pge``) – populated by the + # ``submit_source_model`` terminal tool. 
Forward-ref string typing avoids a + # circular import between ``agents.tools`` and ``agents.agent_mapping_pge``. + source_model: Optional["SourceModel"] = None + + # Mapping PGE semantic critic output (``agent_mapping_pge``) – populated by + # the ``submit_evaluation`` terminal tool of the Sprint 6 Critic agent. + # Same forward-ref pattern as ``source_model`` to avoid a circular import. + semantic_eval_report: Optional["EvalReport"] = None diff --git a/src/agents/tools/evaluation.py b/src/agents/tools/evaluation.py new file mode 100644 index 00000000..e4f62f94 --- /dev/null +++ b/src/agents/tools/evaluation.py @@ -0,0 +1,205 @@ +"""Terminal tool for the mapping-PGE Semantic Critic (Sprint 6). + +The Critic audits ONE submitted mapping for semantic correctness after the +deterministic (stage-1) evaluator has already passed. It submits its verdict +through ``submit_evaluation`` — the terminal tool defined here — which +constructs an :class:`EvalReport` (stage="semantic") and stamps it onto +``ctx.semantic_eval_report``. + +This module deliberately mirrors the shape of the other terminal tools +(``submit_source_model``, ``submit_entity_mapping``, …) — pure-Python handler +with a JSON-schema definition for OpenAI function calling, exported via +``EVALUATION_TOOL_DEFINITIONS`` / ``EVALUATION_TOOL_HANDLERS`` aggregates. +""" + +import json +from typing import Callable, Dict, List, Optional + +from back.core.logging import get_logger +from agents.tools.context import ToolContext + +logger = get_logger(__name__) + + +# ===================================================== +# OpenAI function-calling definition +# ===================================================== + +SUBMIT_EVALUATION_DEF: dict = { + "type": "function", + "function": { + "name": "submit_evaluation", + "description": ( + "Submit the final semantic evaluation. Terminal tool — call exactly once " + "when you have a confident verdict. status MUST be 'PASS' or 'FAIL'. 
" + "If failing, populate failures[] with at least one entry. " + "Set bubble_to_planner=true ONLY when the wrong TABLE was chosen " + "(not just a wrong column within the right table)." + ), + "parameters": { + "type": "object", + "properties": { + "status": {"type": "string", "enum": ["PASS", "FAIL"]}, + "failures": { + "type": "array", + "items": { + "type": "object", + "properties": { + "check": {"type": "string"}, + "expected": {"type": "string"}, + "observed": {"type": "string"}, + "hint": {"type": "string"}, + }, + "required": ["check", "expected", "observed", "hint"], + }, + "description": "Empty when status is PASS.", + }, + "bubble_to_planner": {"type": "boolean"}, + "reasoning": { + "type": "string", + "description": "One-paragraph summary of the audit reasoning.", + }, + }, + "required": ["status"], + }, + }, +} + + +# ===================================================== +# Handler +# ===================================================== + + +def tool_submit_evaluation( + ctx: ToolContext, + *, + status: str = "", + failures: Optional[list] = None, + bubble_to_planner: bool = False, + reasoning: str = "", + **_kwargs, +) -> str: + """Construct an EvalReport from the critic's submission and store on ctx. + + Contract: + * ``status`` MUST be one of ``"PASS"`` or ``"FAIL"`` — anything else is + rejected as a JSON error so the agent loop can coach the LLM and + continue (it does NOT terminate the loop). + * On ``FAIL`` with an empty ``failures`` list, a generic + ``semantic_audit`` failure is synthesised so the resulting report is + coherent (status=FAIL <=> failures non-empty, matching + :func:`evaluator.report.build_report` semantics). + * ``bubble_to_planner=True`` is demoted to False when status is PASS — + same invariant the deterministic evaluator's :func:`build_report` + enforces (a passing evaluation should not escalate). 
+ """ + logger.info( + "tool_submit_evaluation: status=%s, failures=%d, bubble=%s, reasoning=%d chars", + status, + len(failures or []), + bubble_to_planner, + len(reasoning or ""), + ) + + if status not in ("PASS", "FAIL"): + logger.warning("tool_submit_evaluation: invalid status=%r", status) + return json.dumps( + { + "success": False, + "error": f"invalid status: {status!r} (must be PASS or FAIL)", + } + ) + + # Lazy import — these contracts live in agent_mapping_pge and importing + # them at module load time would create a cycle through + # ``agents.tools.context``. + from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport + + eval_failures: List[EvalFailure] = [] + for f in failures or []: + if not isinstance(f, dict): + continue + eval_failures.append( + EvalFailure( + kind="semantic", + check=str(f.get("check") or ""), + expected=str(f.get("expected") or ""), + observed=str(f.get("observed") or ""), + hint=str(f.get("hint") or ""), + ) + ) + + # status=PASS <=> failures empty. If the LLM submitted both, clamp the + # failures list and warn — keeping a passing report internally coherent. + if status == "PASS" and eval_failures: + logger.warning( + "tool_submit_evaluation: status=PASS with %d failures — clamping to []", + len(eval_failures), + ) + eval_failures = [] + + # If status=FAIL but no failures, synthesise a generic one so the report + # is coherent (status=FAIL <=> failures non-empty). + if status == "FAIL" and not eval_failures: + logger.debug( + "tool_submit_evaluation: synthesising semantic_audit failure for " + "FAIL with no failures[]" + ) + eval_failures.append( + EvalFailure( + kind="semantic", + check="semantic_audit", + expected="PASS", + observed="FAIL", + hint=reasoning or "critic returned FAIL without specific failures", + ) + ) + + # If status=PASS but bubble flag is True, demote — matches + # ``build_report``'s behaviour and the documented invariant: a passing + # evaluation does not escalate to the Planner. 
+ if status == "PASS" and bubble_to_planner: + logger.warning( + "tool_submit_evaluation: bubble_to_planner=True with status=PASS — " + "demoting to False" + ) + bubble_to_planner = False + + metrics: Dict[str, str] = {"reasoning": reasoning} if reasoning else {} + + report = EvalReport( + status=status, + stage="semantic", + metrics=metrics, + failures=eval_failures, + bubble_to_planner=bool(bubble_to_planner), + ) + ctx.semantic_eval_report = report + + logger.info( + "tool_submit_evaluation: stored EvalReport status=%s, failures=%d, bubble=%s", + report.status, + len(report.failures), + report.bubble_to_planner, + ) + + return json.dumps( + { + "success": True, + "status": status, + "failures": len(eval_failures), + "bubble_to_planner": report.bubble_to_planner, + } + ) + + +# ===================================================== +# Aggregates +# ===================================================== + +EVALUATION_TOOL_DEFINITIONS: List[dict] = [SUBMIT_EVALUATION_DEF] + +EVALUATION_TOOL_HANDLERS: Dict[str, Callable] = { + "submit_evaluation": tool_submit_evaluation, +} diff --git a/src/agents/tools/mapping.py b/src/agents/tools/mapping.py index f2eaa19a..87dbecc0 100644 --- a/src/agents/tools/mapping.py +++ b/src/agents/tools/mapping.py @@ -28,9 +28,20 @@ def tool_submit_entity_mapping( id_column: str = "", label_column: str = "", attribute_mappings: Optional[dict] = None, + unmapped_attributes: Optional[list] = None, **_kwargs, ) -> str: - """Record a completed entity mapping.""" + """Record a completed entity mapping. + + ``unmapped_attributes`` lets the Generator stage declare ontology attributes + it intentionally did not map to a column, with a one-sentence ``reason``. + Items may be either bare strings (attribute name only) or dicts of shape + ``{"name": str, "reason": str}`` — the richer dict form is preferred for + downstream consumption but bare strings round-trip too. Anything else is + coerced to a string for safety. 
This enforces the PGE "no silent drops" + invariant: every ontology attribute is either in ``attribute_mappings`` or + in ``unmapped_attributes``. + """ logger.info("tool_submit_entity_mapping: '%s' (uri=%s)", class_name, class_uri) if not class_uri or not sql_query: logger.warning("tool_submit_entity_mapping: missing required fields") @@ -42,6 +53,22 @@ def tool_submit_entity_mapping( .rstrip(";") ) + # Normalise ``unmapped_attributes`` — accept either form, persist as-is for + # dicts, leave bare strings as strings (validation/coverage is downstream). + normalised_unmapped: List = [] + for item in unmapped_attributes or []: + if isinstance(item, dict) and "name" in item: + normalised_unmapped.append( + { + "name": str(item.get("name", "")), + "reason": str(item.get("reason", "")), + } + ) + elif isinstance(item, str): + normalised_unmapped.append(item) + else: + normalised_unmapped.append(str(item)) + mapping = { "ontology_class": class_uri, "class_name": class_name, @@ -49,6 +76,7 @@ def tool_submit_entity_mapping( "id_column": id_column, "label_column": label_column, "attribute_mappings": attribute_mappings or {}, + "unmapped_attributes": normalised_unmapped, } logger.debug( @@ -78,12 +106,14 @@ def tool_submit_entity_mapping( logger.debug("tool_submit_entity_mapping: appended new mapping") mapped_attrs = len(mapping["attribute_mappings"]) + unmapped_count = len(mapping["unmapped_attributes"]) logger.info( - "tool_submit_entity_mapping: '%s' recorded — ID=%s, Label=%s, %d attr(s) mapped", + "tool_submit_entity_mapping: '%s' recorded — ID=%s, Label=%s, %d attr(s) mapped, %d unmapped", class_name, id_column, label_column, mapped_attrs, + unmapped_count, ) return json.dumps( { @@ -92,6 +122,7 @@ def tool_submit_entity_mapping( "id_column": id_column, "label_column": label_column, "attributes_mapped": mapped_attrs, + "attributes_unmapped": unmapped_count, "total_entity_mappings": len(ctx.entity_mappings), } ) @@ -243,6 +274,30 @@ def _extract_label(value: str) -> 
str: ), "additionalProperties": {"type": "string"}, }, + "unmapped_attributes": { + "type": "array", + "description": ( + "Ontology attributes you intentionally did NOT map to a column, " + "each with a one-sentence reason. Use this to satisfy the " + 'no-silent-drops invariant. Preferred shape: ' + '[{"name": "apgarScore", "reason": "absent from source table"}]. ' + "Bare strings are also accepted but discouraged." + ), + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Ontology attribute name.", + }, + "reason": { + "type": "string", + "description": "Why this attribute was not mapped.", + }, + }, + "required": ["name", "reason"], + }, + }, }, "required": [ "class_uri", @@ -279,11 +334,19 @@ def _extract_label(value: str) -> str: }, "source_id_column": { "type": "string", - "description": "Column name for the source entity identifier.", + "description": ( + "The output-column alias in sql_query that holds the " + "source id — alias it AS source_id and pass " + '"source_id" here (NOT the entity\'s id_column).' + ), }, "target_id_column": { "type": "string", - "description": "Column name for the target entity identifier.", + "description": ( + "The output-column alias in sql_query that holds the " + "target id — alias it AS target_id and pass " + '"target_id" here (NOT the entity\'s id_column).' + ), }, "domain": { "type": "string", @@ -317,3 +380,10 @@ def _extract_label(value: str) -> str: "submit_entity_mapping": tool_submit_entity_mapping, "submit_relationship_mapping": tool_submit_relationship_mapping, } + +# Name-indexed view of MAPPING_TOOL_DEFINITIONS so callers needing a single +# definition (e.g. the EntityGenerator, which only exposes the entity submit +# tool) can look it up by name without re-scanning the list. 
+MAPPING_TOOL_DEFINITIONS_BY_NAME: Dict[str, dict] = { + d["function"]["name"]: d for d in MAPPING_TOOL_DEFINITIONS +} diff --git a/src/agents/tools/planner.py b/src/agents/tools/planner.py new file mode 100644 index 00000000..a1de9af9 --- /dev/null +++ b/src/agents/tools/planner.py @@ -0,0 +1,707 @@ +""" +Planner tools – used by the mapping-PGE Planner agent (Sprint 2+). + +Exposes the OpenAI function-calling tools that let the Planner LLM probe +source tables and submit a validated ``SourceModel`` artefact: + +* ``sample_table`` — N random rows from a table (n capped at 100). +* ``column_value_overlap`` — one-sided distinct-value overlap between two columns. +* ``normalized_value_overlap`` — same metric, but each side is a scalar SQL + expression, so canonical-key normalizations can be proven before commit. +* ``distinct_count`` — uniqueness / completeness of a candidate canonical id. +* ``submit_source_model`` — terminal tool: validates the candidate SourceModel + JSON against :class:`agents.agent_mapping_pge.contracts.SourceModel` and stores + the dataclass instance on :attr:`ToolContext.source_model`. + +All handlers return JSON strings (same convention as ``agents.tools.sql``) +and stringify scalar values for the LLM-facing surface. +""" + +import json +import re +from typing import Any, Callable, Dict, List, Optional, Tuple + +from back.core.logging import get_logger +from agents.tools.context import ToolContext + +logger = get_logger(__name__) + + +# Cap on ``n`` in ``sample_table`` to keep the LLM context bounded. +_SAMPLE_TABLE_MAX_N = 100 +_SAMPLE_TABLE_DEFAULT_N = 20 + + +# Permissive but injection-safe SQL identifier shape. We allow dots (for +# fully-qualified ``catalog.schema.table``) and backticks (for quoted +# identifiers), plus the usual alphanumerics + underscore. Anything else +# — semicolons, whitespace, quotes, comment markers — is rejected. 
+_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9_.`]+$") + + +# SQL keywords whose presence in a "normalization expression" indicates the +# string is no longer a scalar expression but a smuggled clause / subquery / +# DDL. A legitimate canonical-key expression (regexp_extract, regexp_replace, +# concat, substring, lower, upper, trim, coalesce, ||, string literals) needs +# none of these. Matched case-insensitively as whole words. +_EXPR_FORBIDDEN_WORDS = frozenset( + { + "select", + "from", + "where", + "join", + "union", + "intersect", + "except", + "insert", + "update", + "delete", + "drop", + "alter", + "create", + "grant", + "revoke", + "table", + "into", + "exec", + "execute", + "call", + "merge", + "values", + "having", + "group", + "order", + "limit", + } +) +_EXPR_WORD_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") + + +def _validate_safe_expression(expr: str, *, role: str) -> Optional[str]: + """Return None if ``expr`` is a safe scalar SQL expression; else an error. + + Unlike :func:`_validate_identifier`, this permits the parentheses, commas, + quotes and operators a canonical-key normalization needs (e.g. + ``regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-\\d+)', 1)`` or + ``concat(regexp_extract(delivery_id, '...', 1), '-del')``). It still gets + interpolated into SQL via an f-string, so it is gated against the obvious + injection vectors: statement terminators, comment markers, and any SQL + keyword that would turn the scalar into a clause/subquery/DDL. 
+ """ + if not isinstance(expr, str) or not expr.strip(): + return f"invalid {role}: must be a non-empty string" + if ";" in expr or "--" in expr or "/*" in expr or "*/" in expr: + return ( + f"invalid {role}: must not contain ';' or SQL comment markers " + f"(got {expr!r})" + ) + bad = sorted( + { + w.lower() + for w in _EXPR_WORD_RE.findall(expr) + if w.lower() in _EXPR_FORBIDDEN_WORDS + } + ) + if bad: + return ( + f"invalid {role}: a canonical-key expression must be a single scalar " + f"expression, not a clause/subquery. Forbidden keyword(s): " + f"{', '.join(bad)} (got {expr!r})" + ) + return None + + +def _validate_identifier(name: str, *, role: str) -> Optional[str]: + """Return None if ``name`` is a valid SQL identifier; else an error message. + + Used to gate identifiers that get interpolated into SQL via f-strings. + Even though today's callers are LLMs (not untrusted users), a hallucinated + identifier like ``t; DROP TABLE x`` or ``nhs FROM secrets--`` would + otherwise execute. + """ + if not isinstance(name, str) or not _IDENTIFIER_RE.fullmatch(name): + return f"invalid {role}: {name!r}" + return None + + +def _run_query( + ctx: ToolContext, + sql: str, + *, + tool_name: str, +) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]: + """Execute the SQL via the client. Returns ``(rows, None)`` on success, + ``(None, error_str)`` on failure. On failure the SQL is logged at ERROR + level alongside the exception (previously only at DEBUG). 
+ """ + try: + result = ctx.client.execute_query(sql) + return result, None + except Exception as exc: + logger.error( + "%s: query failed: %s\nSQL: %s", tool_name, exc, sql, exc_info=True + ) + return None, str(exc) + + +# ===================================================== +# Tool implementations +# ===================================================== + + +def tool_sample_table( + ctx: ToolContext, *, full_name: str = "", n: Any = _SAMPLE_TABLE_DEFAULT_N, **_kwargs +) -> str: + """Return N random sample rows from ``full_name`` so the agent can see + real values (not just column types). ``n`` is capped at 100. + """ + logger.info("tool_sample_table: full_name=%s, n=%s", full_name, n) + if not full_name: + return json.dumps({"success": False, "error": "full_name is required"}) + + err = _validate_identifier(full_name, role="full_name") + if err is not None: + return json.dumps({"success": False, "error": err}) + + # Strict ``n`` parsing: a malformed value is a tool-call error, not a + # silent fallback. The default (when ``n`` is omitted) is already the int + # ``_SAMPLE_TABLE_DEFAULT_N``, so ``int(n)`` is a no-op in that case. 
+ try: + n_int = int(n) + except (TypeError, ValueError): + return json.dumps({"success": False, "error": f"invalid n: {n!r}"}) + capped_n = max(1, min(n_int, _SAMPLE_TABLE_MAX_N)) + + sql = f"SELECT * FROM {full_name} ORDER BY RAND() LIMIT {capped_n}" + logger.debug("tool_sample_table: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_sample_table") + if err is not None: + return json.dumps({"success": False, "error": err}) + + rows = rows or [] + columns: List[str] = list(rows[0].keys()) if rows else [] + stringified_rows: List[List[Optional[str]]] = [] + for row in rows: + stringified_rows.append( + [str(row[c]) if row.get(c) is not None else None for c in columns] + ) + logger.info( + "tool_sample_table: %d row(s) × %d column(s)", + len(stringified_rows), + len(columns), + ) + return json.dumps( + { + "success": True, + "columns": columns, + "rows": stringified_rows, + "row_count": len(stringified_rows), + } + ) + + +def tool_column_value_overlap( + ctx: ToolContext, + *, + from_table: str = "", + from_column: str = "", + to_table: str = "", + to_column: str = "", + **_kwargs, +) -> str: + """Compute the one-sided overlap + ``|distinct(from) ∩ distinct(to)| / |distinct(from)|``. + + The numerator dedupes ``from`` before intersecting. Returns 0.0 (and a + note) when ``from_distinct_count`` is zero to avoid division by zero. 
+ """ + logger.info( + "tool_column_value_overlap: %s.%s ↔ %s.%s", + from_table, + from_column, + to_table, + to_column, + ) + if not (from_table and from_column and to_table and to_column): + return json.dumps( + { + "success": False, + "error": "from_table, from_column, to_table, to_column are all required", + } + ) + + for value, role in ( + (from_table, "from_table"), + (from_column, "from_column"), + (to_table, "to_table"), + (to_column, "to_column"), + ): + err = _validate_identifier(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + + sql = ( + "WITH from_distinct AS (" + f" SELECT DISTINCT {from_column} AS v FROM {from_table} " + f" WHERE {from_column} IS NOT NULL" + ")," + " to_distinct AS (" + f" SELECT DISTINCT {to_column} AS v FROM {to_table} " + f" WHERE {to_column} IS NOT NULL" + ")," + " inter AS (" + " SELECT v FROM from_distinct INTERSECT SELECT v FROM to_distinct" + ") " + "SELECT (SELECT COUNT(*) FROM from_distinct) AS from_distinct_count, " + " (SELECT COUNT(*) FROM to_distinct) AS to_distinct_count, " + " (SELECT COUNT(*) FROM inter) AS intersection_count" + ) + logger.debug("tool_column_value_overlap: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_column_value_overlap") + if err is not None: + return json.dumps({"success": False, "error": err}) + if not rows: + return json.dumps( + {"success": False, "error": "overlap query returned no rows"} + ) + + row = rows[0] + from_distinct = int(row.get("from_distinct_count", 0) or 0) + to_distinct = int(row.get("to_distinct_count", 0) or 0) + intersection = int(row.get("intersection_count", 0) or 0) + + if from_distinct == 0: + result: Dict[str, Any] = { + "success": True, + "overlap_pct": 0.0, + "from_distinct_count": 0, + "to_distinct_count": to_distinct, + "intersection_count": 0, + "note": ( + f"{from_table}.{from_column} has zero distinct non-null values; " + "overlap_pct defaulted to 0.0 (no division by zero)." 
+ ), + } + else: + result = { + "success": True, + "overlap_pct": intersection / from_distinct, + "from_distinct_count": from_distinct, + "to_distinct_count": to_distinct, + "intersection_count": intersection, + # Symmetric shape with the zero-denom branch: downstream consumers + # can read ``note`` unconditionally. + "note": "", + } + logger.info( + "tool_column_value_overlap: overlap_pct=%.4f (%d/%d)", + result["overlap_pct"], + intersection, + from_distinct, + ) + return json.dumps(result) + + +def tool_normalized_value_overlap( + ctx: ToolContext, + *, + from_table: str = "", + from_expr: str = "", + to_table: str = "", + to_expr: str = "", + **_kwargs, +) -> str: + """Like :func:`tool_column_value_overlap`, but each side is an arbitrary + scalar SQL *expression* rather than a bare column. + + This is the tool the Planner uses to PROVE a canonical-key normalization + works before committing it. When two tables that map to the same ontology + class have 0% raw-column overlap, the values are trust-local encodings of + the same key. The Planner proposes a normalization expression per table + (e.g. ``regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-\\d+)', 1)``) + and calls this tool to confirm the expressions land in a common value + space (overlap_pct > 0). A still-zero overlap means the normalization is + wrong — fix it before submitting. 
+ """ + logger.info( + "tool_normalized_value_overlap: %s[%s] ↔ %s[%s]", + from_table, + from_expr, + to_table, + to_expr, + ) + if not (from_table and from_expr and to_table and to_expr): + return json.dumps( + { + "success": False, + "error": "from_table, from_expr, to_table, to_expr are all required", + } + ) + + for value, role in ((from_table, "from_table"), (to_table, "to_table")): + err = _validate_identifier(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + for value, role in ((from_expr, "from_expr"), (to_expr, "to_expr")): + err = _validate_safe_expression(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + + sql = ( + "WITH from_distinct AS (" + f" SELECT DISTINCT {from_expr} AS v FROM {from_table} " + f" WHERE {from_expr} IS NOT NULL AND {from_expr} <> ''" + ")," + " to_distinct AS (" + f" SELECT DISTINCT {to_expr} AS v FROM {to_table} " + f" WHERE {to_expr} IS NOT NULL AND {to_expr} <> ''" + ")," + " inter AS (" + " SELECT v FROM from_distinct INTERSECT SELECT v FROM to_distinct" + ") " + "SELECT (SELECT COUNT(*) FROM from_distinct) AS from_distinct_count, " + " (SELECT COUNT(*) FROM to_distinct) AS to_distinct_count, " + " (SELECT COUNT(*) FROM inter) AS intersection_count" + ) + logger.debug("tool_normalized_value_overlap: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_normalized_value_overlap") + if err is not None: + return json.dumps({"success": False, "error": err}) + if not rows: + return json.dumps( + {"success": False, "error": "overlap query returned no rows"} + ) + + row = rows[0] + from_distinct = int(row.get("from_distinct_count", 0) or 0) + to_distinct = int(row.get("to_distinct_count", 0) or 0) + intersection = int(row.get("intersection_count", 0) or 0) + + if from_distinct == 0: + result: Dict[str, Any] = { + "success": True, + "overlap_pct": 0.0, + "from_distinct_count": 0, + "to_distinct_count": to_distinct, + 
"intersection_count": 0, + "note": ( + f"{from_expr} over {from_table} produced zero distinct non-empty " + "values; the expression likely does not match the data — revise it." + ), + } + else: + result = { + "success": True, + "overlap_pct": intersection / from_distinct, + "from_distinct_count": from_distinct, + "to_distinct_count": to_distinct, + "intersection_count": intersection, + "note": "", + } + logger.info( + "tool_normalized_value_overlap: overlap_pct=%.4f (%d/%d)", + result["overlap_pct"], + intersection, + from_distinct, + ) + return json.dumps(result) + + +def tool_distinct_count( + ctx: ToolContext, *, full_name: str = "", column: str = "", **_kwargs +) -> str: + """Report row / distinct / null counts for ``full_name.column`` and + derive ``is_unique`` and ``is_complete`` flags. + + * ``is_unique = distinct_count == row_count - null_count`` — i.e. the + non-null subset has no duplicates. + * ``is_complete = null_count == 0`` — no missing values. + """ + logger.info("tool_distinct_count: %s.%s", full_name, column) + if not (full_name and column): + return json.dumps( + {"success": False, "error": "full_name and column are required"} + ) + + for value, role in ((full_name, "full_name"), (column, "column")): + err = _validate_identifier(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + + sql = ( + f"SELECT COUNT(*) AS row_count, " + f" COUNT(DISTINCT {column}) AS distinct_count, " + f" COUNT(*) - COUNT({column}) AS null_count " + f"FROM {full_name}" + ) + logger.debug("tool_distinct_count: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_distinct_count") + if err is not None: + return json.dumps({"success": False, "error": err}) + if not rows: + return json.dumps( + {"success": False, "error": "distinct_count query returned no rows"} + ) + + row = rows[0] + row_count = int(row.get("row_count", 0) or 0) + distinct_count = int(row.get("distinct_count", 0) or 0) + null_count = 
int(row.get("null_count", 0) or 0) + non_null_rows = row_count - null_count + + result = { + "success": True, + "row_count": row_count, + "distinct_count": distinct_count, + "null_count": null_count, + "is_unique": distinct_count == non_null_rows, + "is_complete": null_count == 0, + } + logger.info( + "tool_distinct_count: rows=%d, distinct=%d, nulls=%d, unique=%s, complete=%s", + row_count, + distinct_count, + null_count, + result["is_unique"], + result["is_complete"], + ) + return json.dumps(result) + + +def tool_submit_source_model( + ctx: ToolContext, *, model: Optional[dict] = None, **_kwargs +) -> str: + """Terminal Planner tool: validate ``model`` against + :class:`SourceModel` and stash the dataclass on ``ctx.source_model``. + + Only structural validity is checked here (does ``SourceModel.from_dict`` + succeed?). Semantic checks — e.g. coverage against the live ontology — + are the orchestrator's responsibility. + """ + # Local import to keep ``agents.tools`` importable without + # ``agents.agent_mapping_pge`` (avoids circular imports during pkg init). + from agents.agent_mapping_pge.contracts import SourceModel + + logger.info("tool_submit_source_model: validating candidate model") + if model is None or not isinstance(model, dict): + return json.dumps( + {"success": False, "error": "model must be a JSON object"} + ) + + try: + source_model = SourceModel.from_dict(model) + except (KeyError, TypeError, ValueError) as exc: + # ``KeyError`` for missing required fields; ``TypeError`` / ``ValueError`` + # for bad coercions (e.g. confidence not float-parseable). 
+ logger.warning( + "tool_submit_source_model: validation failed: %s: %s", + type(exc).__name__, + exc, + ) + return json.dumps( + { + "success": False, + "error": f"SourceModel validation failed: {type(exc).__name__}: {exc}", + } + ) + + ctx.source_model = source_model + summary = { + "table_roles": len(source_model.table_roles), + "canonical_ids": len(source_model.canonical_ids), + "join_keys": len(source_model.join_keys), + "entity_order_len": len(source_model.mapping_plan.entity_order), + "relationship_order_len": len(source_model.mapping_plan.relationship_order), + } + logger.info("tool_submit_source_model: stored — %s", summary) + return json.dumps({"success": True, "summary": summary}) + + +# ===================================================== +# OpenAI function-calling definitions +# ===================================================== + + +SAMPLE_TABLE_DEF: dict = { + "type": "function", + "function": { + "name": "sample_table", + "description": ( + "Return up to N random sample rows from a table so you can see actual values " + "(not just column types). n defaults to 20 and is capped at 100." + ), + "parameters": { + "type": "object", + "properties": { + "full_name": { + "type": "string", + "description": "Fully-qualified table name (catalog.schema.table).", + }, + "n": { + "type": "integer", + "description": "Sample size (default 20, max 100).", + }, + }, + "required": ["full_name"], + }, + }, +} + + +COLUMN_VALUE_OVERLAP_DEF: dict = { + "type": "function", + "function": { + "name": "column_value_overlap", + "description": ( + "Compute the one-sided overlap |distinct(from) ∩ distinct(to)| / |distinct(from)|. " + "Use this to validate a candidate join key before committing it to the SourceModel." 
+ ), + "parameters": { + "type": "object", + "properties": { + "from_table": { + "type": "string", + "description": "Fully-qualified source table.", + }, + "from_column": { + "type": "string", + "description": "Column on the source side (numerator denominator).", + }, + "to_table": { + "type": "string", + "description": "Fully-qualified target table.", + }, + "to_column": { + "type": "string", + "description": "Column on the target side.", + }, + }, + "required": ["from_table", "from_column", "to_table", "to_column"], + }, + }, +} + + +NORMALIZED_VALUE_OVERLAP_DEF: dict = { + "type": "function", + "function": { + "name": "normalized_value_overlap", + "description": ( + "Same overlap metric as column_value_overlap, but each side is a " + "scalar SQL EXPRESSION instead of a bare column. Use this to PROVE a " + "canonical-key normalization before committing it: when two tables " + "that map to the same ontology class have 0% raw-column overlap, " + "propose a normalization expression per table (e.g. " + "regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-\\d+)', 1)) and " + "call this to confirm overlap_pct > 0. A still-zero result means the " + "expression is wrong — fix it before submit_source_model. Expressions " + "must be a single scalar (functions/literals/operators only); " + "subqueries and SQL keywords are rejected." + ), + "parameters": { + "type": "object", + "properties": { + "from_table": { + "type": "string", + "description": "Fully-qualified source table.", + }, + "from_expr": { + "type": "string", + "description": ( + "Scalar SQL expression over the source table that " + "produces the canonical key (e.g. a regexp_extract / " + "concat). Bare column names are also accepted." 
+ ), + }, + "to_table": { + "type": "string", + "description": "Fully-qualified target table.", + }, + "to_expr": { + "type": "string", + "description": "Scalar SQL expression over the target table.", + }, + }, + "required": ["from_table", "from_expr", "to_table", "to_expr"], + }, + }, +} + + +DISTINCT_COUNT_DEF: dict = { + "type": "function", + "function": { + "name": "distinct_count", + "description": ( + "Report row_count / distinct_count / null_count for a column, with is_unique " + "and is_complete flags. Use this to vet a candidate canonical-ID column." + ), + "parameters": { + "type": "object", + "properties": { + "full_name": { + "type": "string", + "description": "Fully-qualified table name (catalog.schema.table).", + }, + "column": { + "type": "string", + "description": "Column to characterise.", + }, + }, + "required": ["full_name", "column"], + }, + }, +} + + +SUBMIT_SOURCE_MODEL_DEF: dict = { + "type": "function", + "function": { + "name": "submit_source_model", + "description": ( + "Terminal Planner tool. Submit the final SourceModel JSON (matching " + "SourceModel.to_dict() shape). Validates the structure and stores the " + "dataclass on the ToolContext for the Generator stage to consume." + ), + "parameters": { + "type": "object", + "properties": { + "model": { + "type": "object", + "description": ( + "JSON-encoded SourceModel with table_roles, canonical_ids, " + "join_keys, and mapping_plan." 
+ ), + } + }, + "required": ["model"], + }, + }, +} + + +# ===================================================== +# Aggregate exports +# ===================================================== + + +PLANNER_TOOL_DEFINITIONS: List[dict] = [ + SAMPLE_TABLE_DEF, + COLUMN_VALUE_OVERLAP_DEF, + NORMALIZED_VALUE_OVERLAP_DEF, + DISTINCT_COUNT_DEF, + SUBMIT_SOURCE_MODEL_DEF, +] + + +PLANNER_TOOL_HANDLERS: Dict[str, Callable] = { + "sample_table": tool_sample_table, + "column_value_overlap": tool_column_value_overlap, + "normalized_value_overlap": tool_normalized_value_overlap, + "distinct_count": tool_distinct_count, + "submit_source_model": tool_submit_source_model, +} diff --git a/src/back/core/agents/AgentClient.py b/src/back/core/agents/AgentClient.py index 5d7ab263..3162c7a3 100644 --- a/src/back/core/agents/AgentClient.py +++ b/src/back/core/agents/AgentClient.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from agents.agent_owl_generator.engine import AgentResult from agents.agent_auto_assignment.engine import AgentResult as AutoAssignAgentResult + from agents.agent_mapping_pge.engine import AgentResult as MappingPGEAgentResult from agents.agent_auto_icon_assign.engine import ( AgentResult as IconAssignAgentResult, ) @@ -131,6 +132,67 @@ def run_auto_assignment( max_iterations=max_iterations, ) + def run_mapping_pge( + self, + *, + host: str, + token: str, + endpoint_name: str, + client: Any, + metadata: Any, + ontology: Any, + entity_mappings: Any, + relationship_mappings: Any, + documents: Any = None, + on_step: Optional[Callable] = None, + max_iterations: int = 10, + ) -> "MappingPGEAgentResult": + """Propose entity and relationship SQL mappings using the mapping PGE agent. + + This is the Planner–Generator–Evaluator (PGE) mapping engine + (``agents.agent_mapping_pge``). It shares the same call signature as + :meth:`run_auto_assignment` but drives a deterministic, coverage-checked + loop with a semantic critic and a structured run log. 
+ + Args: + host: Databricks workspace host (with or without ``https://``). + token: Bearer token for the workspace APIs. + endpoint_name: Model serving endpoint name for the agent. + client: SQL client (typically :class:`~back.core.databricks.DatabricksClient`) + used to validate or sample queries against the configured warehouse. + metadata: Schema context (for example UC table metadata) for the agent. + ontology: Ontology dict describing classes and properties to map. + entity_mappings: Existing or partial entity mapping list for the agent + to refine or extend. + relationship_mappings: Existing or partial relationship mapping list. + documents: Optional list of document dicts (``name``, ``content``) for + grounding. + on_step: Optional progress callback invoked by the agent loop. + max_iterations: Upper bound on agent refinement iterations. + + Returns: + Structured result from ``agents.agent_mapping_pge`` describing + proposed mappings, per-item status, and PGE diagnostics. + + Raises: + Exception: Propagates any failure raised by ``run_agent``. 
+ """ + from agents.agent_mapping_pge import run_agent + + return run_agent( + host=host, + token=token, + endpoint_name=endpoint_name, + client=client, + metadata=metadata, + ontology=ontology, + entity_mappings=entity_mappings, + relationship_mappings=relationship_mappings, + documents=documents, + on_step=on_step, + max_iterations=max_iterations, + ) + def run_icon_assign( self, *, diff --git a/src/back/objects/mapping/Mapping.py b/src/back/objects/mapping/Mapping.py index 13ed00fe..7e40ea35 100644 --- a/src/back/objects/mapping/Mapping.py +++ b/src/back/objects/mapping/Mapping.py @@ -27,7 +27,7 @@ _MAX_DOC_CHARS = 50_000 if TYPE_CHECKING: - from agents.agent_auto_assignment.engine import AgentResult as AutoAssignAgentResult + from agents.agent_mapping_pge.engine import AgentResult as AutoAssignAgentResult SINGLE_ITEM_MAX_ITERATIONS = 15 @@ -78,13 +78,18 @@ def auto_assign_with_agent( on_step: Optional[Callable[[str, int], None]] = None, max_iterations: Optional[int] = None, ) -> "AutoAssignAgentResult": - """Run ``agent_auto_assignment`` (blocking). + """Run the mapping-PGE agent (``agent_mapping_pge``) — blocking. + + Returns an :class:`AgentResult` with the standard ``entity_mappings`` + and ``relationship_mappings`` plus three PGE-specific extras + (``source_model``, ``mapping_evaluations``, ``mapping_run_log``) that + the caller can persist on the session. ``client`` is typically a :class:`~back.core.databricks.DatabricksClient` built with the domain warehouse. Call from a background thread when started from HTTP. """ - from agents.agent_auto_assignment import run_agent + from agents.agent_mapping_pge import run_agent return run_agent( host=host, @@ -165,6 +170,12 @@ def run_auto_assign_task( total_iterations = 0 total_usage = {"prompt_tokens": 0, "completion_tokens": 0} chunk_errors: List[str] = [] + # PGE-specific extras accumulated across chunks. 
Each chunk + # re-plans, so ``last_source_model`` reflects the most recent + # plan; per-item evaluations / run logs concatenate cleanly. + last_source_model: Optional[Dict[str, Any]] = None + merged_mapping_evaluations: Dict[str, Any] = {} + merged_mapping_run_log: List[Any] = [] for chunk_idx, chunk in enumerate(chunks): chunk_num = chunk_idx + 1 @@ -261,6 +272,19 @@ def on_step(msg: str, progress_pct: int = 0) -> None: for k in total_usage: total_usage[k] += agent_result.usage.get(k, 0) + # PGE extras — accumulate. The new engine returns these as + # dicts/lists (drop-in compatible). The legacy engine omitted + # them; ``getattr`` with defaults keeps us tolerant. + chunk_source_model = getattr(agent_result, "source_model", None) + if chunk_source_model: + last_source_model = chunk_source_model + chunk_evals = getattr(agent_result, "mapping_evaluations", None) or {} + if chunk_evals: + merged_mapping_evaluations.update(chunk_evals) + chunk_run_log = getattr(agent_result, "mapping_run_log", None) or [] + if chunk_run_log: + merged_mapping_run_log.extend(chunk_run_log) + e_done = len(entity_mapping_by_uri) r_done = len(rel_mapping_by_uri) @@ -345,6 +369,9 @@ def on_step(msg: str, progress_pct: int = 0) -> None: all_relationship_mappings, existing_entity_mappings=entity_mappings, existing_relationship_mappings=relationship_mappings, + source_model=last_source_model, + mapping_evaluations=merged_mapping_evaluations or None, + mapping_run_log=merged_mapping_run_log or None, ) message = f"Completed: {e_count} entities, {r_count} relationships mapped" @@ -449,6 +476,11 @@ def on_step(msg: str, progress_pct: int = 0) -> None: tm.fail_task(task.id, "Agent completed but produced no mapping") return + # PGE extras from this single-item run — passed through verbatim. 
+ single_source_model = getattr(agent_result, "source_model", None) + single_evals = getattr(agent_result, "mapping_evaluations", None) or None + single_run_log = getattr(agent_result, "mapping_run_log", None) or None + if item_type == "entity": Mapping.save_mappings_to_session( session_id, @@ -456,6 +488,9 @@ def on_step(msg: str, progress_pct: int = 0) -> None: agent_result.entity_mappings, None, existing_entity_mappings=existing_entity_mappings, + source_model=single_source_model, + mapping_evaluations=single_evals, + mapping_run_log=single_run_log, ) else: Mapping.save_mappings_to_session( @@ -464,6 +499,9 @@ def on_step(msg: str, progress_pct: int = 0) -> None: None, agent_result.relationship_mappings, existing_relationship_mappings=existing_relationship_mappings, + source_model=single_source_model, + mapping_evaluations=single_evals, + mapping_run_log=single_run_log, ) tm.complete_task( @@ -918,6 +956,9 @@ def save_mappings_to_session( *, existing_entity_mappings: Optional[list] = None, existing_relationship_mappings: Optional[list] = None, + source_model: Optional[Dict[str, Any]] = None, + mapping_evaluations: Optional[Dict[str, Any]] = None, + mapping_run_log: Optional[List[Any]] = None, ) -> None: if not session_id: logger.warning("save_mappings_to_session: no session_id — skipping") @@ -986,6 +1027,21 @@ def save_mappings_to_session( else: assignment["relationships"] = relationship_mappings + # Mapping-PGE extras — persisted alongside the assignment so the + # UI (future work) and downstream observability can surface + # planner state, per-item evaluation reports, and the per-item + # attempt log without re-running the agent. 
+ if source_model is not None: + assignment["source_model"] = source_model + if mapping_evaluations is not None: + merged_evals = dict(assignment.get("mapping_evaluations") or {}) + merged_evals.update(mapping_evaluations) + assignment["mapping_evaluations"] = merged_evals + if mapping_run_log is not None: + existing_log = list(assignment.get("mapping_run_log") or []) + existing_log.extend(mapping_run_log) + assignment["mapping_run_log"] = existing_log + domain_node = bucket.setdefault("domain", {}) domain_node["assignment_changed"] = True diff --git a/tests/agents/agent_mapping_pge/__init__.py b/tests/agents/agent_mapping_pge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents/agent_mapping_pge/test_contracts.py b/tests/agents/agent_mapping_pge/test_contracts.py new file mode 100644 index 00000000..35498c15 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_contracts.py @@ -0,0 +1,126 @@ +"""Smoke tests for the PGE contracts. + +These are intentionally narrow — they only assert that every contract +dataclass round-trips cleanly through ``to_dict`` / ``from_dict`` so that +downstream sprints (Planner, Generator, orchestrator) can rely on +JSON-safe serialisation for MLflow artefacts and registry persistence. 
+""" + +import json + +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + EvalFailure, + EvalReport, + JoinKey, + MappingPlan, + RetryState, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) + + +def _roundtrip(obj): + """Serialise to dict -> JSON string -> dict -> reconstruct via from_dict.""" + cls = type(obj) + d = obj.to_dict() + encoded = json.dumps(d) + back = cls.from_dict(json.loads(encoded)) + return back, d + + +def test_source_model_roundtrip(): + sm = SourceModel( + table_roles=[ + TableRole( + table="cat.sch.mothers", + ontology_class_candidates=[ + TableRoleCandidate( + uri="http://ex.org#Mother", confidence=0.92, reason="row match" + ), + ], + ), + ], + canonical_ids=[ + CanonicalId( + ontology_class="http://ex.org#Mother", + canonical_column_per_table={"cat.sch.mothers": "nhs_number"}, + format_note="NHS number, 10 digits, no separators", + ) + ], + join_keys=[ + JoinKey( + from_ref="cat.sch.babies.mother_nhs", + to_ref="cat.sch.mothers.nhs_number", + confidence=0.88, + overlap_pct=0.97, + kind="same_trust_fk", + ) + ], + mapping_plan=MappingPlan( + entity_order=["http://ex.org#Mother", "http://ex.org#Baby"], + relationship_order=["http://ex.org#hasBaby"], + skip=[SkipItem(item="http://ex.org#Ghost", reason="no source table")], + ), + ) + back, d = _roundtrip(sm) + assert back.to_dict() == d + assert back.table_roles[0].table == "cat.sch.mothers" + assert back.canonical_ids[0].canonical_column_per_table["cat.sch.mothers"] == "nhs_number" + assert back.join_keys[0].kind == "same_trust_fk" + assert back.mapping_plan.skip[0].item == "http://ex.org#Ghost" + + +def test_eval_report_roundtrip(): + report = EvalReport( + status="FAIL", + stage="deterministic", + metrics={"row_count": 0}, + failures=[ + EvalFailure( + kind="structural", + check="row_count", + expected="> 0", + observed="0", + hint="fix the FROM clause", + ) + ], + bubble_to_planner=True, + ) + back, d = _roundtrip(report) + assert back.to_dict() == d + assert 
back.status == "FAIL" + assert back.failures[0].check == "row_count" + + +def test_retry_state_roundtrip_with_and_without_report(): + rs_empty = RetryState(item_uri="http://ex.org#Mother") + back, d = _roundtrip(rs_empty) + assert back.to_dict() == d + assert back.last_eval_report is None + + rs = RetryState( + item_uri="http://ex.org#Baby", + generator_attempts=2, + planner_reinvocations=1, + last_eval_report=EvalReport( + status="FAIL", + stage="deterministic", + failures=[ + EvalFailure( + kind="structural", + check="total_edges", + expected="> 0", + observed="0", + hint="fix join", + ) + ], + bubble_to_planner=True, + ), + ) + back, d = _roundtrip(rs) + assert back.to_dict() == d + assert back.last_eval_report is not None + assert back.last_eval_report.failures[0].check == "total_edges" diff --git a/tests/agents/agent_mapping_pge/test_coverage.py b/tests/agents/agent_mapping_pge/test_coverage.py new file mode 100644 index 00000000..1baa1b75 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_coverage.py @@ -0,0 +1,140 @@ +"""Tests for deterministic coverage enforcement + derived mappings. + +These lock in the invariant that broke the pipeline before: coverage is +computed from the ontology, NOT from the Planner's discretionary plan, so every +class and every relationship is always attempted. 
+""" + +from agents.agent_mapping_pge import coverage as cov +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + MappingPlan, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) + +# A miniature class hierarchy mirroring the maternity shape: +# Person (abstract) -> Patient (abstract) -> {Mother, Baby concrete} +# Encounter (abstract) -> {Visit concrete} +PERSON = "u#Person" +PATIENT = "u#Patient" +MOTHER = "u#Mother" +BABY = "u#Baby" +ENCOUNTER = "u#Encounter" +VISIT = "u#Visit" +ATTENDED = "u#attended" # Visit -> Mother + + +def _ontology() -> dict: + return { + "entities": [ + {"uri": PERSON, "name": "Person", "parent": "", "attributes": []}, + {"uri": PATIENT, "name": "Patient", "parent": "Person", + "attributes": [{"name": "nhsnumber"}]}, + {"uri": MOTHER, "name": "Mother", "parent": "Patient", + "attributes": [{"name": "nhsnumber"}, {"name": "postcode"}]}, + {"uri": BABY, "name": "Baby", "parent": "Patient", + "attributes": [{"name": "nhsnumber"}]}, + {"uri": ENCOUNTER, "name": "Encounter", "parent": "", "attributes": []}, + {"uri": VISIT, "name": "Visit", "parent": "Encounter", "attributes": []}, + ], + "relationships": [ + {"uri": ATTENDED, "name": "attended", "domain": VISIT, "range": MOTHER}, + ], + } + + +def _source_model(skip_some: bool = False) -> SourceModel: + # Planner only assigns tables to the concrete leaves + only plans a SUBSET. 
+ roles = [ + TableRole(table="c.s.mother", ontology_class_candidates=[ + TableRoleCandidate(uri=MOTHER, confidence=0.9)]), + TableRole(table="c.s.baby", ontology_class_candidates=[ + TableRoleCandidate(uri=BABY, confidence=0.9)]), + TableRole(table="c.s.visit", ontology_class_candidates=[ + TableRoleCandidate(uri=VISIT, confidence=0.9)]), + ] + cids = [ + CanonicalId(ontology_class=MOTHER, canonical_column_per_table={"c.s.mother": "nhs"}), + CanonicalId(ontology_class=VISIT, canonical_column_per_table={"c.s.visit": "vid"}), + ] + plan = MappingPlan( + entity_order=[MOTHER], # deliberately incomplete + relationship_order=[], # deliberately empty + skip=[SkipItem(item=BABY, reason="planner skipped it")] if skip_some else [], + ) + return SourceModel(table_roles=roles, canonical_ids=cids, mapping_plan=plan) + + +def test_classify_abstract_vs_concrete(): + concrete, abstract = cov.classify(_ontology(), _source_model()) + assert abstract == {PERSON, PATIENT, ENCOUNTER} + assert {MOTHER, BABY, VISIT} <= concrete + + +def test_full_entity_order_covers_all_classes_abstracts_last(): + order = cov.full_entity_order(_ontology(), _source_model()) + assert set(order) == {PERSON, PATIENT, MOTHER, BABY, ENCOUNTER, VISIT} + # Every concrete leaf precedes its abstract ancestors. + assert order.index(MOTHER) < order.index(PATIENT) < order.index(PERSON) + assert order.index(VISIT) < order.index(ENCOUNTER) + + +def test_skip_list_does_not_reduce_coverage(): + # Even when the Planner skips Baby, full coverage still includes it. 
+ order = cov.full_entity_order(_ontology(), _source_model(skip_some=True)) + assert BABY in order + + +def test_full_relationship_order_includes_all(): + order = cov.full_entity_order(_ontology(), _source_model()) + rels = cov.full_relationship_order(_ontology(), order, _source_model()) + assert rels == [ATTENDED] + + +def test_concrete_leaf_descendants(): + concrete, _ = cov.classify(_ontology(), _source_model()) + assert set(cov.concrete_leaf_descendants(PERSON, _ontology(), concrete)) == {MOTHER, BABY} + assert set(cov.concrete_leaf_descendants(PATIENT, _ontology(), concrete)) == {MOTHER, BABY} + assert set(cov.concrete_leaf_descendants(ENCOUNTER, _ontology(), concrete)) == {VISIT} + + +def test_build_abstract_union_mapping_reuses_subclass_sql(): + mother_em = { + "ontology_class": MOTHER, "id_column": "ID", + "sql_query": "SELECT nhs AS ID, nhs AS nhsnumber, pc AS postcode FROM c.s.mother", + "attribute_mappings": {"nhsnumber": "nhsnumber", "postcode": "postcode"}, + } + baby_em = { + "ontology_class": BABY, "id_column": "ID", + "sql_query": "SELECT CONCAT(nhs,'-baby') AS ID, nhs AS nhsnumber FROM c.s.baby", + "attribute_mappings": {"nhsnumber": "nhsnumber"}, + } + patient = next(e for e in _ontology()["entities"] if e["uri"] == PATIENT) + m = cov.build_abstract_union_mapping(PATIENT, patient, [mother_em, baby_em]) + assert m is not None + assert m["id_column"] == "ID" + assert m["derived"] == "abstract_union" + # Patient's own attribute (nhsnumber) is projected; both subclass SQLs reused. 
+ assert "nhsnumber" in m["attribute_mappings"] + assert "UNION ALL" in m["sql_query"] + assert "c.s.mother" in m["sql_query"] and "c.s.baby" in m["sql_query"] + + +def test_build_abstract_union_mapping_none_when_no_subclasses(): + patient = next(e for e in _ontology()["entities"] if e["uri"] == PATIENT) + assert cov.build_abstract_union_mapping(PATIENT, patient, []) is None + + +def test_synthetic_endpoint_mapping_from_canonical_ids(): + em = cov.synthetic_endpoint_mapping(_source_model(), VISIT) + assert em is not None + assert em["id_column"] == "ID" + assert "c.s.visit" in em["sql_query"] + assert em["derived"] == "synthetic_endpoint" + + +def test_synthetic_endpoint_mapping_none_for_unknown_class(): + assert cov.synthetic_endpoint_mapping(_source_model(), "u#Nope") is None diff --git a/tests/agents/agent_mapping_pge/test_critic.py b/tests/agents/agent_mapping_pge/test_critic.py new file mode 100644 index 00000000..cad85d79 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_critic.py @@ -0,0 +1,684 @@ +"""Tests for the mapping-PGE Semantic Critic agent (Sprint 6). + +Mirrors the structure of ``test_relationship_generator.py``. The Critic is a +narrow tool-calling ReAct loop terminated by ``submit_evaluation``. These +tests exercise the loop's control flow with a *fake LLM* — a stub that +replaces ``call_serving_endpoint`` at module level and returns canned +responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. + +What we DO exercise: +* PASS verdict terminates immediately. +* FAIL with bubble_to_planner=False (column-level). +* FAIL with bubble_to_planner=True (table-level) — the bubble flag survives. +* PASS+bubble is demoted (matches build_report behaviour). +* FAIL with empty failures[] synthesises a generic semantic failure. +* Invalid status does NOT terminate — loop continues, accepts a valid retry. +* Text-only response → failure with "without submitting evaluation". 
+* Iteration-budget exhaustion → failure with "iteration budget". +* User prompt surfaces structural-stage metrics. +* User prompt for relationships includes domain/range sections. +""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge.evaluator import critic as critic_mod +from agents.agent_mapping_pge.evaluator.critic import ( + CriticResult, + CriticStep, + run_critic, +) + + +# ===================================================== +# Fake LLM scaffolding +# ===================================================== + + +_ENTITY_URI = "http://ex.org/maternity#Mother" +_REL_URI = "http://ex.org/maternity#motherOf" + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": [{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + self.last_messages: Optional[List[dict]] = None + self.first_messages: Optional[List[dict]] = None + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + msgs: Optional[List[dict]] = None + if len(args) >= 4 and isinstance(args[3], list): + msgs = args[3] + elif "messages" in kwargs: + msgs = kwargs["messages"] + if msgs is not None: + snapshot = [dict(m) for m in msgs] + if self.first_messages is None: + self.first_messages = snapshot + self.last_messages = 
snapshot + + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever.""" + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(critic_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + monkeypatch.setattr(critic_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures +# ===================================================== + + +def _entity_definition() -> dict: + return { + "uri": _ENTITY_URI, + "label": "Mother", + "name": "Mother", + "comment": "A pregnant woman in the maternity dataset.", + "attributes": [ + {"name": "nhsNumber", "type": "string"}, + {"name": "dateOfBirth", "type": "date"}, + ], + } + + +def _relationship_definition() -> dict: + return { + "uri": _REL_URI, + "label": "motherOf", + "name": "motherOf", + "comment": "Links a Mother to each of her babies.", + "domain": _ENTITY_URI, + "range": "http://ex.org/maternity#Baby", + } + + +def _entity_submitted_mapping() -> dict: + return { + "ontology_class": _ENTITY_URI, + "class_name": "Mother", + "sql_query": "SELECT nhs_number AS ID, nhs_number AS Label FROM cat.sch.mothers WHERE nhs_number IS NOT NULL", + "id_column": "nhs_number", + "label_column": "nhs_number", + "attribute_mappings": {"nhsNumber": "nhs_number"}, + "unmapped_attributes": [ + {"name": "dateOfBirth", "reason": "column absent from this table"} + ], + } + + +def _relationship_submitted_mapping() -> dict: + return { + "property": 
_REL_URI, + "property_name": "motherOf", + "sql_query": "SELECT mother_nhs_number AS source_id, baby_id AS target_id FROM cat.sch.babies", + "source_id_column": "nhs_number", + "target_id_column": "baby_id", + "domain": _ENTITY_URI, + "range_class": "http://ex.org/maternity#Baby", + } + + +def _source_model_slice() -> dict: + return { + "candidate_tables": [ + { + "table": "cat.sch.mothers", + "confidence": 0.9, + "reason": "row per mother, nhs_number as PK", + } + ], + "canonical_id": { + "canonical_column_per_table": {"cat.sch.mothers": "nhs_number"}, + "format_note": "10-digit NHS number", + }, + } + + +def _stage1_metrics(**overrides) -> dict: + base = { + "row_count": 100, + "distinct_ids": 100, + "null_ids": 0, + } + base.update(overrides) + return base + + +def _valid_pass_submit() -> dict: + return { + "status": "PASS", + "failures": [], + "bubble_to_planner": False, + "reasoning": "Sampled values match the Mother concept; column semantics OK.", + } + + +def _valid_fail_column_submit() -> dict: + return { + "status": "FAIL", + "failures": [ + { + "check": "column_semantics", + "expected": "delivery date", + "observed": "appointment_date is a booking date", + "hint": "Use `delivery_dttm` instead of `appointment_date`.", + } + ], + "bubble_to_planner": False, + "reasoning": "Wrong column within the right table.", + } + + +def _valid_fail_table_submit() -> dict: + return { + "status": "FAIL", + "failures": [ + { + "check": "table_selection", + "expected": "labour_delivery", + "observed": "antenatal_visits", + "hint": "Switch to `labour_delivery` table for the Delivery class.", + } + ], + "bubble_to_planner": True, + "reasoning": "Wrong table chosen — bubble to Planner.", + } + + +def _run_entity_critic( + fake: Callable[..., dict], + *, + max_iterations: int = 6, + item_kind: str = "entity", + item_uri: str = _ENTITY_URI, + item_definition: Optional[dict] = None, + submitted_mapping: Optional[dict] = None, + stage1_metrics: Optional[dict] = None, + client: Any 
= None, +) -> CriticResult: + return run_critic( + host="https://x", + token="t", + endpoint_name="ep", + client=client, + item_kind=item_kind, + item_uri=item_uri, + item_definition=item_definition + if item_definition is not None + else _entity_definition(), + submitted_mapping=submitted_mapping + if submitted_mapping is not None + else _entity_submitted_mapping(), + source_model_slice=_source_model_slice(), + stage1_metrics=stage1_metrics + if stage1_metrics is not None + else _stage1_metrics(), + max_iterations=max_iterations, + ) + + +# ===================================================== +# 1. PASS verdict terminates immediately +# ===================================================== + + +def test_pass_verdict(monkeypatch, no_sleep): + """First LLM turn submits PASS → success=True, status=PASS, iterations=1.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_pass_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert isinstance(result, CriticResult) + assert result.success is True + assert result.iterations == 1 + assert result.report is not None + assert result.report.status == "PASS" + assert result.report.stage == "semantic" + assert result.report.failures == [] + assert result.report.bubble_to_planner is False + assert result.error == "" + # Step recording: one tool_call + one tool_result. + assert [s.step_type for s in result.steps] == ["tool_call", "tool_result"] + assert result.steps[0].tool_name == "submit_evaluation" + + +# ===================================================== +# 2. 
FAIL with bubble_to_planner=False (column-level) +# ===================================================== + + +def test_fail_column_level(monkeypatch, no_sleep): + """status=FAIL, bubble_to_planner=False → column-level failure preserved.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_fail_column_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "FAIL" + assert result.report.bubble_to_planner is False + assert len(result.report.failures) == 1 + failure = result.report.failures[0] + assert failure.kind == "semantic" + assert failure.check == "column_semantics" + assert "delivery_dttm" in failure.hint + + +# ===================================================== +# 3. FAIL with bubble_to_planner=True (table-level) +# ===================================================== + + +def test_fail_table_level_bubbles(monkeypatch, no_sleep): + """status=FAIL, bubble_to_planner=True → bubble flag preserved on report.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_fail_table_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "FAIL" + assert result.report.bubble_to_planner is True + assert len(result.report.failures) == 1 + failure = result.report.failures[0] + assert failure.check == "table_selection" + assert "labour_delivery" in failure.hint + + +# ===================================================== +# 4. 
PASS with bubble_to_planner=True is demoted +# ===================================================== + + +def test_demotes_pass_with_bubble(monkeypatch, no_sleep): + """A PASS verdict that asks to bubble is demoted to bubble=False.""" + bad_pass = _valid_pass_submit() + bad_pass["bubble_to_planner"] = True + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("submit_evaluation", bad_pass)] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "PASS" + # The bubble flag must have been demoted. + assert result.report.bubble_to_planner is False + + +# ===================================================== +# 5. FAIL with empty failures[] synthesises one +# ===================================================== + + +def test_fail_without_failures_synthesises_one(monkeypatch, no_sleep): + """status=FAIL with empty failures[] gets a generic semantic failure + synthesised so the report stays coherent.""" + fail_no_failures = { + "status": "FAIL", + "failures": [], + "bubble_to_planner": False, + "reasoning": "Something is off but I can't pinpoint it.", + } + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", fail_no_failures) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "FAIL" + assert len(result.report.failures) == 1 + f = result.report.failures[0] + assert f.kind == "semantic" + assert f.check == "semantic_audit" + # The reasoning is folded into the synthetic failure's hint when present. + assert "Something is off" in f.hint + + +# ===================================================== +# 6. 
Invalid status does NOT terminate — agent retries +# ===================================================== + + +def test_invalid_status_rejected(monkeypatch, no_sleep): + """A submit with status='UNKNOWN' must NOT terminate the loop; the + Critic must keep going and a follow-up submit with a valid status + should succeed. + """ + fake = FakeLLM( + [ + # Turn 1: invalid status → handler returns success=False, loop continues. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_evaluation", + { + "status": "UNKNOWN", + "failures": [], + "bubble_to_planner": False, + "reasoning": "n/a", + }, + tc_id="bad", + ) + ] + ), + # Turn 2: valid PASS submit → terminates. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_evaluation", + _valid_pass_submit(), + tc_id="good", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.iterations == 2 + assert result.report is not None + assert result.report.status == "PASS" + # Both submit attempts left tool_call + tool_result steps (4 total). + assert len(result.steps) == 4 + + # The corrective tool message on the 2nd LLM call must contain the + # "invalid status" error so the LLM sees why its first attempt failed. + assert fake.last_messages is not None + tool_messages = [m for m in fake.last_messages if m.get("role") == "tool"] + assert tool_messages, "expected at least one tool message on the 2nd call" + first_tool_msg = tool_messages[0].get("content", "") + parsed = json.loads(first_tool_msg) + assert parsed.get("success") is False + assert "invalid status" in parsed.get("error", "") + + +# ===================================================== +# 7. Text without terminal call → failure +# ===================================================== + + +def test_text_without_terminal_fails(monkeypatch, no_sleep): + """A plain-text response is treated as failure — the Critic must + terminate via submit_evaluation. 
+ """ + fake = FakeLLM( + [_llm_response(content="I am thinking…", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is False + assert result.iterations == 1 + assert result.report is None + assert "without submitting evaluation" in result.error + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 8. Iteration-budget exhaustion → failure +# ===================================================== + + +def test_exhausts_budget(monkeypatch, no_sleep): + """Endless sample_table calls with max_iterations=3 → fail with + ``iteration budget`` and three iterations of steps recorded.""" + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="probe", + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890"}] + + result = _run_entity_critic(fake, max_iterations=3, client=FakeClient()) + + assert result.success is False + assert result.iterations == 3 + assert result.report is None + assert "iteration budget" in result.error + # 3 iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 9. 
User prompt surfaces stage1 metrics +# ===================================================== + + +def test_user_prompt_includes_stage1_metrics(monkeypatch, no_sleep): + """The first LLM call's user message must contain stage1 metric values + so the Critic sees the structural context.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_pass_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + _run_entity_critic( + fake, + stage1_metrics={"row_count": 1234, "distinct_ids": 1234, "null_ids": 0}, + ) + + assert fake.first_messages is not None + assert fake.first_messages[0]["role"] == "system" + assert fake.first_messages[1]["role"] == "user" + user_content = fake.first_messages[1]["content"] + assert "1234" in user_content + assert "STRUCTURAL CHECK METRICS" in user_content + + +# ===================================================== +# 10. Relationship audit surfaces domain/range +# ===================================================== + + +def test_user_prompt_distinguishes_entity_vs_relationship(monkeypatch, no_sleep): + """When item_kind='relationship', the user prompt must include the + 'domain' and 'range' lines that an entity prompt would not have.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_pass_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_critic( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + item_kind="relationship", + item_uri=_REL_URI, + item_definition=_relationship_definition(), + submitted_mapping=_relationship_submitted_mapping(), + source_model_slice=_source_model_slice(), + stage1_metrics=_stage1_metrics(), + ) + + assert fake.first_messages is not None + user_content = fake.first_messages[1]["content"] + # The kind is surfaced explicitly. + assert "relationship" in user_content + # The relationship-specific domain/range sections appear. 
+ assert "domain:" in user_content + assert "range:" in user_content + # And it is framed as a relationship submitted mapping. + assert "SUBMITTED MAPPING (relationship)" in user_content + # The relationship endpoint columns are surfaced too. + assert "source_id_column" in user_content + assert "target_id_column" in user_content + + +# ===================================================== +# 11. Step recording invariants +# ===================================================== + + +def test_records_steps(monkeypatch, no_sleep): + """Every tool-calling iteration produces one ``tool_call`` step + immediately followed by one ``tool_result`` step with the same tool_name.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_evaluation", + _valid_pass_submit(), + tc_id="b", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890"}] + + result = _run_entity_critic(fake, client=FakeClient()) + + assert result.success is True + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == "tool_result" + assert call_step.tool_name == result_step.tool_name + assert isinstance(call_step, CriticStep) + assert isinstance(result_step, CriticStep) diff --git a/tests/agents/agent_mapping_pge/test_deterministic_evaluator.py b/tests/agents/agent_mapping_pge/test_deterministic_evaluator.py new file mode 100644 index 00000000..94ec3f35 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_deterministic_evaluator.py @@ -0,0 +1,916 @@ +"""Tests for the deterministic (stage-1) evaluator of the mapping PGE pipeline. 
+ +The evaluator is a pure function: it takes a submitted mapping plus an +injectable ``execute_sql_fn`` and returns an ``EvalReport`` summarising +structural failures. No LLM, no Databricks connection. + +``execute_sql_fn`` contract (for the evaluator): + def execute_sql_fn(sql: str) -> dict +returning:: + + {"columns": [...], "rows": [{col: value, ...}, ...]} + +This is the full result set — not the 3-row sample emitted by +``agents.tools.sql.tool_execute_sql``. The PGE orchestrator (Sprint 7) is +responsible for wiring a runner that yields full rows. +""" + +from typing import Dict, List + +import pytest + +from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport +from agents.agent_mapping_pge.evaluator.deterministic import ( + evaluate_entity_mapping, + evaluate_relationship_mapping, +) +from agents.agent_mapping_pge.evaluator.report import build_report + + +# ===================================================== +# Fixtures +# ===================================================== + + +MOTHER_CLASS = { + "uri": "http://ex.org/maternity#Mother", + "name": "Mother", + "attributes": [ + {"name": "firstName"}, + {"name": "lastName"}, + {"name": "nhsNumber"}, + ], +} + +BABY_CLASS = { + "uri": "http://ex.org/maternity#Baby", + "name": "Baby", + "attributes": [ + {"name": "birthWeight"}, + ], +} + + +def _mother_mapping(*, attribute_mappings=None, unmapped_attributes=None): + mapping = { + "ontology_class": MOTHER_CLASS["uri"], + "class_name": "Mother", + "sql_query": "SELECT nhs_number AS ID, full_name AS Label, first_name, last_name, nhs_number FROM cat.sch.mothers", + "id_column": "ID", + "label_column": "Label", + "attribute_mappings": attribute_mappings + if attribute_mappings is not None + else { + "firstName": "first_name", + "lastName": "last_name", + "nhsNumber": "nhs_number", + }, + } + if unmapped_attributes is not None: + mapping["unmapped_attributes"] = unmapped_attributes + return mapping + + +def _baby_mapping(): + return { + 
"ontology_class": BABY_CLASS["uri"], + "class_name": "Baby", + "sql_query": "SELECT baby_id AS ID, baby_id AS Label, birth_weight FROM cat.sch.babies", + "id_column": "ID", + "label_column": "Label", + "attribute_mappings": {"birthWeight": "birth_weight"}, + } + + +def _mother_to_baby_relationship(): + return { + "property": "http://ex.org/maternity#hasBaby", + "property_name": "hasBaby", + "sql_query": ( + "SELECT mother_nhs AS source_id, baby_id AS target_id " + "FROM cat.sch.babies" + ), + "source_id_column": "source_id", + "target_id_column": "target_id", + "source_class": MOTHER_CLASS["uri"], + "target_class": BABY_CLASS["uri"], + } + + +def _make_sql_fn(table: dict): + """Return an execute_sql_fn closure that routes by SQL substring. + + ``table`` maps a unique substring -> {"columns": [...], "rows": [...]}. + """ + + def fn(sql: str) -> dict: + for needle, payload in table.items(): + if needle in sql: + return payload + raise AssertionError(f"unexpected SQL in test: {sql}") + + return fn + + +# ===================================================== +# Entity evaluator +# ===================================================== + + +class TestEvaluateEntityMapping: + def test_pass_happy_path(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice Smith", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + { + "ID": "NHS-002", + "Label": "Bob Jones", + "first_name": "Bob", + "last_name": "Jones", + "nhs_number": "NHS-002", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert isinstance(report, EvalReport) + assert report.status == "PASS" + assert report.stage == "deterministic" + assert report.failures == [] + assert report.bubble_to_planner is False + assert report.metrics["row_count"] == 2 
+ assert report.metrics["distinct_id_count"] == 2 + assert report.metrics["null_id_count"] == 0 + assert report.metrics["unmapped_attribute_pct"] == 0.0 + + def test_fail_row_count_zero_bubbles_to_planner(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + {"mothers": {"columns": ["ID", "Label"], "rows": []}} + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + check_names = [f.check for f in report.failures] + assert "row_count" in check_names + + def test_fail_duplicate_ids(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + { + "ID": "NHS-001", + "Label": "Alice dup", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False + check_names = [f.check for f in report.failures] + assert "distinct_id_count" in check_names + + def test_fail_null_ids(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": None, + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": None, + }, + { + "ID": "NHS-002", + "Label": "Bob", + "first_name": "Bob", + "last_name": "Jones", + "nhs_number": "NHS-002", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + check_names = [f.check for f in 
report.failures] + assert "null_id_count" in check_names + + def test_fail_unmapped_attribute(self): + # Omit lastName from attribute_mappings, no unmapped_attributes list. + mapping = _mother_mapping( + attribute_mappings={ + "firstName": "first_name", + "nhsNumber": "nhs_number", + }, + ) + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + check_names = [f.check for f in report.failures] + assert "unmapped_attribute_pct" in check_names + # 1 of 3 attributes missing -> ~0.333 + assert report.metrics["unmapped_attribute_pct"] == pytest.approx(1 / 3) + + def test_pass_when_unmapped_attribute_is_declared(self): + mapping = _mother_mapping( + attribute_mappings={ + "firstName": "first_name", + "nhsNumber": "nhs_number", + }, + unmapped_attributes=["lastName"], + ) + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.metrics["unmapped_attribute_pct"] == 0.0 + + def test_pass_when_unmapped_attribute_is_declared_as_dict(self): + """The Generator may emit unmapped_attributes as [{name, reason}, ...]. 
+ Hashing dicts would crash the evaluator — names must be extracted.""" + mapping = _mother_mapping( + attribute_mappings={ + "firstName": "first_name", + "nhsNumber": "nhs_number", + }, + unmapped_attributes=[ + {"name": "lastName", "reason": "no source column"} + ], + ) + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.metrics["unmapped_attribute_pct"] == 0.0 + + def test_report_is_json_serialisable(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + d = report.to_dict() + assert d["status"] == "PASS" + assert d["stage"] == "deterministic" + assert isinstance(d["metrics"], dict) + assert isinstance(d["failures"], list) + + def test_sql_execution_error_becomes_fail_not_crash(self): + """A mapping whose SQL parses but fails at runtime (e.g. a UNION + type mismatch) must yield a FAIL report with the error as a hint — + never propagate and crash the agent run. 
+ """ + mapping = _mother_mapping() + + def boom(sql: str) -> dict: + raise RuntimeError( + "[CAST_INVALID_INPUT] The value 'x-preg-1-baby' of the type " + '"STRING" cannot be cast to "BIGINT"' + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=boom, + ) + + assert report.status == "FAIL" + # A runtime SQL error is the Generator's to fix, not a re-plan trigger. + assert report.bubble_to_planner is False + checks = [f.check for f in report.failures] + assert "sql_execution" in checks + # The underlying DB error is surfaced for the generator to act on. + assert "CAST_INVALID_INPUT" in report.metrics.get("sql_error", "") + # Report must remain JSON-serialisable. + import json as _json + + _json.dumps(report.to_dict()) + + +# ===================================================== +# Relationship evaluator +# ===================================================== + + +def _entity_rows(ids): + return { + "columns": ["ID", "Label"], + "rows": [{"ID": i, "Label": i} for i in ids], + } + + +class TestEvaluateRelationshipMapping: + def test_pass_happy_path(self): + rel = _mother_to_baby_relationship() + sql_fn = _make_sql_fn( + { + # Relationship edges + "source_id": { + "columns": ["source_id", "target_id"], + "rows": [ + {"source_id": "NHS-001", "target_id": "B-1"}, + {"source_id": "NHS-002", "target_id": "B-2"}, + ], + }, + # Source entity universe + "mothers": _entity_rows(["NHS-001", "NHS-002", "NHS-003"]), + # Target entity universe + "babies": _entity_rows(["B-1", "B-2", "B-3"]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.bubble_to_planner is False + assert report.metrics["total_edges"] == 2 + assert report.metrics["dangling_source_pct"] == 0.0 + assert report.metrics["dangling_target_pct"] == 0.0 + + def 
test_sql_execution_error_becomes_fail_not_crash(self): + """A relationship (or its endpoint-universe) SQL that errors at + runtime must yield a FAIL report, not crash the agent run. + """ + rel = _mother_to_baby_relationship() + + def boom(sql: str) -> dict: + raise RuntimeError("[UNRESOLVED_COLUMN] cannot resolve `target_id`") + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=boom, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False + assert "sql_execution" in [f.check for f in report.failures] + assert "UNRESOLVED_COLUMN" in report.metrics.get("sql_error", "") + + def test_fail_47_pct_dangling_source_bubbles(self): + rel = _mother_to_baby_relationship() + # 100 edges, 47 source_ids unknown to source universe. + edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + # Only NHS-001..NHS-053 exist as mothers. 
+ mother_ids = [f"NHS-{i:03d}" for i in range(1, 54)] + baby_ids = [f"B-{i}" for i in range(1, 201)] + + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows(mother_ids), + "babies": _entity_rows(baby_ids), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False # 0.47 < 0.5 threshold + check_names = [f.check for f in report.failures] + assert "dangling_source_pct" in check_names + assert report.metrics["dangling_source_pct"] == pytest.approx(0.47) + + def test_fail_above_50_pct_dangling_source_bubbles_to_planner(self): + rel = _mother_to_baby_relationship() + edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + # Only NHS-001..NHS-040 are known mothers -> 60% dangling + mother_ids = [f"NHS-{i:03d}" for i in range(1, 41)] + baby_ids = [f"B-{i}" for i in range(1, 201)] + + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows(mother_ids), + "babies": _entity_rows(baby_ids), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + + def test_pass_3_pct_dangling_source_under_threshold(self): + rel = _mother_to_baby_relationship() + # 100 edges, only 3 source ids not in mother universe -> 3%. 
+ edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + mother_ids = [f"NHS-{i:03d}" for i in range(1, 98)] # 97 known, 3 dangling + baby_ids = [f"B-{i}" for i in range(1, 201)] + + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows(mother_ids), + "babies": _entity_rows(baby_ids), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.bubble_to_planner is False + assert report.metrics["dangling_source_pct"] == pytest.approx(0.03) + + def test_fail_zero_edges_bubbles_to_planner(self): + rel = _mother_to_baby_relationship() + sql_fn = _make_sql_fn( + { + "source_id": {"columns": ["source_id", "target_id"], "rows": []}, + "mothers": _entity_rows(["NHS-001"]), + "babies": _entity_rows(["B-1"]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + check_names = [f.check for f in report.failures] + assert "total_edges" in check_names + + def test_cross_source_band_fail_when_outside(self): + rel = _mother_to_baby_relationship() + # 100 edges, all source ids in mother universe. 
+ edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.25, 0.4), + ) + # overlap_pct = 1.0 (every source row matches a target id); outside band. + assert report.status == "FAIL" + check_names = [f.check for f in report.failures] + assert "cross_source_overlap_pct" in check_names + + def test_cross_source_band_pass_when_inside(self): + rel = _mother_to_baby_relationship() + # Build edges where only ~30% of source ids match a target id (band 0.25..0.4 ). + edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 30 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.25, 0.4), + ) + # overlap = 30/100 = 0.3, inside band. + assert report.status == "PASS" + assert report.metrics["cross_source_overlap_pct"] == pytest.approx(0.3) + + def test_band_present_overlap_outside_band_with_catastrophic_dangling_bubbles(self): + """Band FAILS (overlap 0.05 << lo=0.25) AND dangling > 0.5 → bubble. 
+ + The realised overlap is materially worse than the Planner predicted, + so the catastrophic-dangling structural failure fires alongside the + band-check failure, and ``bubble_to_planner`` flips True. + """ + rel = _mother_to_baby_relationship() + # 100 edges, only the first 5 target_ids land in the babies universe + # → overlap = 0.05, dangling_target = 0.95. + edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 5 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.25, 0.4), + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + assert report.metrics["dangling_target_pct"] == pytest.approx(0.95) + check_names = [f.check for f in report.failures] + # Both the band failure AND the catastrophic-dangling row must surface. + assert "cross_source_overlap_pct" in check_names + assert "dangling_target_pct_catastrophic" in check_names + # The strict 0.05 dangling_target_pct row is gated behind "band is None" + # — it must NOT appear here. + assert "dangling_target_pct" not in check_names + + def test_band_present_overlap_outside_band_with_mild_dangling_does_not_bubble(self): + """Band FAILS but dangling is exactly at the bubble threshold (not > 0.5) + → status FAIL on the band row but ``bubble_to_planner`` stays False. + """ + rel = _mother_to_baby_relationship() + # 100 edges, 50 land in target universe → overlap = 0.50, dangling = 0.50. + # Band is (0.6, 0.8) so band check fails (0.50 < 0.6); dangling NOT > 0.5. 
+ edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 50 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.6, 0.8), + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False + assert report.metrics["dangling_target_pct"] == pytest.approx(0.5) + check_names = [f.check for f in report.failures] + assert "cross_source_overlap_pct" in check_names + # No catastrophic row because dangling is not strictly > 0.5. + assert "dangling_target_pct_catastrophic" not in check_names + + def test_relationship_evaluator_uses_id_universe_cache(self): + """Sharing a cache across calls avoids re-running the entity SQLs.""" + rel = _mother_to_baby_relationship() + base_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": [ + {"source_id": "NHS-001", "target_id": "B-1"}, + {"source_id": "NHS-002", "target_id": "B-2"}, + ], + }, + "mothers": _entity_rows(["NHS-001", "NHS-002", "NHS-003"]), + "babies": _entity_rows(["B-1", "B-2", "B-3"]), + } + ) + + calls: List[str] = [] + + def counting_fn(sql: str) -> dict: + calls.append(sql) + return base_fn(sql) + + cache: Dict[str, set] = {} + + # First call: source + target entity SQLs + relationship SQL = 3 calls. 
+ evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=counting_fn, + id_universe_cache=cache, + ) + first_call_count = len(calls) + assert first_call_count == 3 + + mother_sql = _mother_mapping()["sql_query"] + baby_sql = _baby_mapping()["sql_query"] + assert mother_sql in cache + assert baby_sql in cache + + # Second call with same cache: only the relationship SQL should be + # re-executed; both entity universes are served from cache. + evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=counting_fn, + id_universe_cache=cache, + ) + + delta = calls[first_call_count:] + assert len(delta) == 1 + assert mother_sql not in delta + assert baby_sql not in delta + + def test_band_absent_catastrophic_target_dangling_bubbles(self): + """No band supplied + dangling_target > 0.5 → strict check fires and bubbles.""" + rel = _mother_to_baby_relationship() + # 100 edges, only 20 target_ids land in babies universe → dangling = 0.80. 
+ edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 20 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + assert report.metrics["dangling_target_pct"] == pytest.approx(0.8) + check_names = [f.check for f in report.failures] + assert "dangling_target_pct" in check_names + + +# ===================================================== +# build_report — bubble demotion warning +# ===================================================== + + +def test_build_report_warns_when_bubble_demoted(caplog): + """``bubble_to_planner=True`` with no failures (status PASS) should + emit a warning, AND silently-PASSing reports should not warn. + """ + import logging + + # PASS + bubble_to_planner=True → warning expected, bubble demoted. + caplog.clear() + with caplog.at_level(logging.WARNING): + passing = build_report( + stage="deterministic", + metrics={"row_count": 1}, + failures=[], + bubble_to_planner=True, + ) + assert passing.status == "PASS" + assert passing.bubble_to_planner is False + assert any( + "bubble_to_planner=True" in rec.message and rec.levelname == "WARNING" + for rec in caplog.records + ) + + # PASS + bubble_to_planner=False → no warning. 
+ caplog.clear() + with caplog.at_level(logging.WARNING): + build_report( + stage="deterministic", + metrics={"row_count": 1}, + failures=[], + bubble_to_planner=False, + ) + assert not any( + "bubble_to_planner=True" in rec.message for rec in caplog.records + ) + + # FAIL + bubble_to_planner=True → no demotion, no warning. + caplog.clear() + failure = EvalFailure( + kind="structural", + check="row_count", + expected="> 0", + observed="0", + hint="", + ) + with caplog.at_level(logging.WARNING): + failing = build_report( + stage="deterministic", + metrics={"row_count": 0}, + failures=[failure], + bubble_to_planner=True, + ) + assert failing.status == "FAIL" + assert failing.bubble_to_planner is True + assert not any( + "bubble_to_planner=True" in rec.message for rec in caplog.records + ) diff --git a/tests/agents/agent_mapping_pge/test_engine.py b/tests/agents/agent_mapping_pge/test_engine.py new file mode 100644 index 00000000..760683eb --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_engine.py @@ -0,0 +1,1117 @@ +"""Tests for the mapping-PGE orchestrator (Sprint 7). + +The orchestrator wires Planner -> Generator(s) -> Evaluator(s) into a single +``run_agent`` entry. These tests exercise the control flow with fake versions +of each sub-agent — no real LLM, no real Databricks. Each test patches the +module-level references in :mod:`engine` so the orchestrator calls the fakes +instead of the production functions. + +What we DO exercise: +* Happy path with both entities and relationships. +* Planner failure aborts cleanly. +* Generator failure records FAIL but continues. +* Evaluator FAIL (non-bubble) drives a retry with a hint. +* Bubble-to-planner triggers Planner re-invocation; budget is global. +* 3-attempt retry budget exhaustion records FAIL_BUDGET. +* Critic PASS / FAIL paths, and the ``skip_semantic_critic`` short-circuit. +* Pre-seeded entity mappings and Planner skip[] entries. +* on_step pct stays non-decreasing across the run. 
+* Id-universe cache shares entity universes across relationships. +""" + +from typing import Any, Dict, List, Optional, Tuple + +import pytest + +from agents.agent_mapping_pge import engine as engine_mod +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + EvalFailure, + EvalReport, + JoinKey, + MappingPlan, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) +from agents.agent_mapping_pge.engine import AgentResult, run_agent +from agents.agent_mapping_pge.evaluator.critic import CriticResult +from agents.agent_mapping_pge.generators.entity import EntityGenResult +from agents.agent_mapping_pge.generators.relationship import RelationshipGenResult +from agents.agent_mapping_pge.planner import PlannerResult + + +# ===================================================== +# Ontology + SourceModel fixtures +# ===================================================== + + +CUSTOMER_URI = "http://test.org/ontology#Customer" +ORDER_URI = "http://test.org/ontology#Order" +HAS_ORDER_URI = "http://test.org/ontology#hasOrder" +ITEM_URI = "http://test.org/ontology#Item" +CONTAINS_URI = "http://test.org/ontology#contains" + +T_CUSTOMERS = "cat.sch.customers" +T_ORDERS = "cat.sch.orders" +T_ITEMS = "cat.sch.items" + + +def _ontology() -> dict: + # Control-flow tests use this two-entity / one-relationship ontology so the + # engine-enforced coverage set equals exactly Customer + Order + hasOrder. + # (The id-universe-cache test below supplies its own 3-entity ontology.) 
+ return { + "entities": [ + { + "uri": CUSTOMER_URI, + "name": "Customer", + "label": "Customer", + "attributes": [{"name": "firstName", "type": "xsd:string"}], + }, + { + "uri": ORDER_URI, + "name": "Order", + "label": "Order", + "attributes": [{"name": "orderDate", "type": "xsd:string"}], + }, + ], + "relationships": [ + { + "uri": HAS_ORDER_URI, + "name": "hasOrder", + "label": "hasOrder", + "domain": CUSTOMER_URI, + "range": ORDER_URI, + }, + ], + } + + +def _source_model(*, with_items: bool = False) -> SourceModel: + table_roles = [ + TableRole( + table=T_CUSTOMERS, + ontology_class_candidates=[ + TableRoleCandidate(uri=CUSTOMER_URI, confidence=0.9, reason="ok") + ], + ), + TableRole( + table=T_ORDERS, + ontology_class_candidates=[ + TableRoleCandidate(uri=ORDER_URI, confidence=0.9, reason="ok") + ], + ), + ] + if with_items: + table_roles.append( + TableRole( + table=T_ITEMS, + ontology_class_candidates=[ + TableRoleCandidate(uri=ITEM_URI, confidence=0.9, reason="ok") + ], + ) + ) + canonical_ids = [ + CanonicalId( + ontology_class=CUSTOMER_URI, + canonical_column_per_table={T_CUSTOMERS: "customer_id"}, + ), + CanonicalId( + ontology_class=ORDER_URI, + canonical_column_per_table={T_ORDERS: "order_id"}, + ), + ] + if with_items: + canonical_ids.append( + CanonicalId( + ontology_class=ITEM_URI, + canonical_column_per_table={T_ITEMS: "item_id"}, + ) + ) + join_keys = [ + JoinKey( + from_ref=f"{T_ORDERS}.customer_id", + to_ref=f"{T_CUSTOMERS}.customer_id", + confidence=0.9, + overlap_pct=0.95, + kind="same_trust_fk", + ), + ] + if with_items: + join_keys.append( + JoinKey( + from_ref=f"{T_ITEMS}.order_id", + to_ref=f"{T_ORDERS}.order_id", + confidence=0.9, + overlap_pct=0.95, + kind="same_trust_fk", + ) + ) + + entity_order = [CUSTOMER_URI, ORDER_URI] + relationship_order = [HAS_ORDER_URI] + if with_items: + entity_order.append(ITEM_URI) + relationship_order.append(CONTAINS_URI) + + return SourceModel( + table_roles=table_roles, + canonical_ids=canonical_ids, + 
join_keys=join_keys, + mapping_plan=MappingPlan( + entity_order=entity_order, + relationship_order=relationship_order, + skip=[], + ), + ) + + +def _entity_mapping(class_uri: str, id_col: str, sql: str) -> dict: + """Shape produced by the EntityGenerator's submit handler.""" + return { + "ontology_class": class_uri, + "class_name": class_uri.rsplit("#", 1)[-1], + "sql_query": sql, + "id_column": id_col, + "label_column": id_col, + "attribute_mappings": {}, + "unmapped_attributes": [], + } + + +def _relationship_mapping( + prop_uri: str, source_col: str, target_col: str, sql: str +) -> dict: + return { + "property": prop_uri, + "property_name": prop_uri.rsplit("#", 1)[-1], + "sql_query": sql, + "source_id_column": source_col, + "target_id_column": target_col, + "domain": CUSTOMER_URI, + "range_class": ORDER_URI, + } + + +# ===================================================== +# Fake sub-agent factories +# ===================================================== + + +class FakePlanner: + """Fake ``run_planner`` returning canned :class:`PlannerResult` values.""" + + def __init__(self, results: List[PlannerResult]): + self.results = list(results) + self.calls = 0 + + def __call__(self, *args: Any, **kwargs: Any) -> PlannerResult: + self.calls += 1 + if not self.results: + raise AssertionError( + f"FakePlanner ran out of canned results on call #{self.calls}" + ) + return self.results.pop(0) + + +class FakeEntityGenerator: + """Routes the call by ontology_class URI to a per-URI list of results.""" + + def __init__(self, results_by_uri: Dict[str, List[EntityGenResult]]): + self.results_by_uri = {k: list(v) for k, v in results_by_uri.items()} + self.calls: List[Tuple[str, Optional[str]]] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> EntityGenResult: + ontology_class = kwargs["ontology_class"] + uri = ontology_class.get("uri", "") + hint = kwargs.get("retry_hint") + self.calls.append((uri, hint)) + queue = self.results_by_uri.get(uri, []) + if not queue: + raise 
AssertionError( + f"FakeEntityGenerator: no canned result for {uri} (call " + f"#{len(self.calls)})" + ) + return queue.pop(0) + + +class FakeRelationshipGenerator: + """Routes the call by ontology_property URI.""" + + def __init__(self, results_by_uri: Dict[str, List[RelationshipGenResult]]): + self.results_by_uri = {k: list(v) for k, v in results_by_uri.items()} + self.calls: List[Tuple[str, Optional[str]]] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> RelationshipGenResult: + prop = kwargs["ontology_property"] + uri = prop.get("uri", "") + hint = kwargs.get("retry_hint") + self.calls.append((uri, hint)) + queue = self.results_by_uri.get(uri, []) + if not queue: + raise AssertionError( + f"FakeRelationshipGenerator: no canned result for {uri}" + ) + return queue.pop(0) + + +class FakeCritic: + """Routes by item_uri.""" + + def __init__(self, reports_by_uri: Dict[str, List[CriticResult]]): + self.reports_by_uri = {k: list(v) for k, v in reports_by_uri.items()} + self.calls: List[str] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> CriticResult: + uri = kwargs["item_uri"] + self.calls.append(uri) + queue = self.reports_by_uri.get(uri, []) + if not queue: + # Default: PASS so tests that don't care about critic still work. 
+ return CriticResult( + success=True, + report=EvalReport( + status="PASS", + stage="semantic", + metrics={}, + failures=[], + bubble_to_planner=False, + ), + ) + return queue.pop(0) + + +class FakeDeterministicEvaluator: + """Stage-1 evaluator stub keyed by mapping uri (class or property).""" + + def __init__(self, reports_by_uri: Dict[str, List[EvalReport]]): + self.reports_by_uri = {k: list(v) for k, v in reports_by_uri.items()} + self.calls: List[str] = [] + + def for_entity(self, *args: Any, **kwargs: Any) -> EvalReport: + mapping = kwargs["mapping"] + uri = mapping.get("ontology_class", "") + return self._next(uri) + + def for_relationship(self, *args: Any, **kwargs: Any) -> EvalReport: + mapping = kwargs["mapping"] + uri = mapping.get("property", "") + return self._next(uri) + + def _next(self, uri: str) -> EvalReport: + self.calls.append(uri) + queue = self.reports_by_uri.get(uri, []) + if not queue: + return EvalReport( + status="PASS", + stage="deterministic", + metrics={}, + failures=[], + bubble_to_planner=False, + ) + return queue.pop(0) + + +# ===================================================== +# Helpers — build typical canned results +# ===================================================== + + +def _ok_entity_gen(class_uri: str, sql: Optional[str] = None) -> EntityGenResult: + id_col = { + CUSTOMER_URI: "customer_id", + ORDER_URI: "order_id", + ITEM_URI: "item_id", + }.get(class_uri, "id") + sql = sql or f"SELECT {id_col} AS ID, {id_col} AS Label FROM tbl_for_{class_uri[-3:]}" + return EntityGenResult( + success=True, + mapping=_entity_mapping(class_uri, id_col, sql), + iterations=2, + usage={"prompt_tokens": 10, "completion_tokens": 5}, + ) + + +def _ok_rel_gen(prop_uri: str) -> RelationshipGenResult: + sql = "SELECT customer_id AS source_id, order_id AS target_id FROM orders" + return RelationshipGenResult( + success=True, + mapping=_relationship_mapping(prop_uri, "customer_id", "order_id", sql), + iterations=2, + usage={"prompt_tokens": 10, 
"completion_tokens": 5}, + ) + + +def _pass_report(stage: str = "deterministic") -> EvalReport: + return EvalReport( + status="PASS", + stage=stage, + metrics={"row_count": 100}, + failures=[], + bubble_to_planner=False, + ) + + +def _fail_report( + *, + stage: str = "deterministic", + hint: str = "wrong column", + bubble: bool = False, +) -> EvalReport: + return EvalReport( + status="FAIL", + stage=stage, + metrics={"row_count": 5}, + failures=[ + EvalFailure( + kind="structural" if stage == "deterministic" else "semantic", + check="some_check", + expected=">0", + observed="0", + hint=hint, + ) + ], + bubble_to_planner=bubble, + ) + + +# ===================================================== +# Common fixtures +# ===================================================== + + +@pytest.fixture +def fake_client() -> Any: + """Lightweight stub with the ``execute_query`` method the orchestrator wraps.""" + + class _Client: + def __init__(self): + self.calls: List[str] = [] + + def execute_query(self, sql: str): + self.calls.append(sql) + # Echo three rows so ``row_count > 0`` if the real evaluator is + # invoked. (Most tests stub the evaluator and never hit this.) 
+ return [ + {"customer_id": 1, "order_id": 10}, + {"customer_id": 2, "order_id": 20}, + {"customer_id": 3, "order_id": 30}, + ] + + return _Client() + + +def _patch_sub_agents( + monkeypatch, + *, + planner: Any, + entity_gen: Any = None, + rel_gen: Any = None, + critic: Any = None, + det_eval: Optional[FakeDeterministicEvaluator] = None, +) -> None: + monkeypatch.setattr(engine_mod, "run_planner", planner) + if entity_gen is not None: + monkeypatch.setattr(engine_mod, "run_entity_generator", entity_gen) + if rel_gen is not None: + monkeypatch.setattr(engine_mod, "run_relationship_generator", rel_gen) + if critic is not None: + monkeypatch.setattr(engine_mod, "run_critic", critic) + if det_eval is not None: + monkeypatch.setattr( + engine_mod, "evaluate_entity_mapping", det_eval.for_entity + ) + monkeypatch.setattr( + engine_mod, + "evaluate_relationship_mapping", + det_eval.for_relationship, + ) + + +def _run(client: Any, **overrides) -> AgentResult: + kwargs = dict( + host="https://test", + token="t", + endpoint_name="ep", + client=client, + metadata={}, + ontology=_ontology(), + skip_semantic_critic=True, + ) + kwargs.update(overrides) + return run_agent(**kwargs) + + +# ===================================================== +# Tests +# ===================================================== + + +def test_happy_path_two_entities_one_relationship(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) # all default PASS + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + assert result.success is True + assert 
len(result.entity_mappings) == 2 + assert len(result.relationship_mappings) == 1 + assert {m["ontology_class"] for m in result.entity_mappings} == { + CUSTOMER_URI, + ORDER_URI, + } + # 3 mapping_run_log entries, all PASS. + assert len(result.mapping_run_log) == 3 + assert all(entry["final_status"] == "PASS" for entry in result.mapping_run_log) + # source_model serialised onto the result. + assert result.source_model is not None + assert "table_roles" in result.source_model + + +def test_planner_failure_aborts(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=False, source_model=None, error="LLM rejected tools")] + ) + _patch_sub_agents(monkeypatch, planner=planner) + + result = _run(fake_client) + + assert result.success is False + assert "LLM rejected tools" in result.error + assert result.entity_mappings == [] + assert result.relationship_mappings == [] + + +def test_generator_failure_records_item_failure_continues_run( + monkeypatch, fake_client +): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + # Customer generator fails 3 times; Order succeeds. + fail = EntityGenResult(success=False, mapping=None, error="generator crashed") + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [fail, fail, fail], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + # Order entity mapped despite Customer failing. 
+ assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + assert not any(m["ontology_class"] == CUSTOMER_URI for m in result.entity_mappings) + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "FAIL_BUDGET" + # The relationship endpoint for hasOrder requires Customer; with Customer + # missing the relationship is recorded but not mapped. + rel_log = next(e for e in result.mapping_run_log if e["item"] == HAS_ORDER_URI) + assert rel_log["final_status"] in {"FAIL_BUDGET", "PASS"} + + +def test_evaluator_fail_retry_with_hint(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + _ok_entity_gen(CUSTOMER_URI, "SELECT bad_col AS ID FROM x"), + _ok_entity_gen(CUSTOMER_URI), + ], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + # First attempt fails, second passes — non-bubble FAIL with a hint. + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [ + _fail_report(hint="use customer_id, not bad_col", bubble=False), + _pass_report(), + ] + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PASS" + assert len(customer_log["attempts"]) == 2 + assert customer_log["attempts"][0]["stage1_status"] == "FAIL" + assert customer_log["attempts"][0]["hint"] == "use customer_id, not bad_col" + # Second EntityGenerator call must have been given the hint. 
+ customer_calls = [c for c in entity_gen.calls if c[0] == CUSTOMER_URI] + assert customer_calls[0][1] is None + assert customer_calls[1][1] == "use customer_id, not bad_col" + + +def test_bubble_to_planner_triggers_replanning(monkeypatch, fake_client): + planner = FakePlanner( + [ + PlannerResult(success=True, source_model=_source_model(), iterations=1), + PlannerResult(success=True, source_model=_source_model(), iterations=1), + ] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + _ok_entity_gen(CUSTOMER_URI), # attempt 1 (bubbles) + _ok_entity_gen(CUSTOMER_URI), # attempt 1 of replan iteration + ], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [ + _fail_report(hint="wrong table", bubble=True), + _pass_report(), + ] + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PASS" + # Planner was invoked twice (initial + 1 replan). + assert planner.calls == 2 + assert result.stats["planner_reinvocations"] == 1 + + +def test_planner_reinvocation_budget_exhausted(monkeypatch, fake_client): + # Budget-agnostic: 1 initial planner call + exactly the replan budget. + budget = engine_mod._PLANNER_REINVOCATION_BUDGET + planner = FakePlanner( + [ + PlannerResult(success=True, source_model=_source_model(), iterations=1) + for _ in range(budget + 1) + ] + ) + bubble = _fail_report(hint="wrong table", bubble=True) + # Customer entity bubbles on every attempt forever. 
+ entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI) for _ in range(40)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [bubble for _ in range(40)], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "FAIL_BUBBLE" + # 1 initial planner call + exactly `budget` replans. + assert planner.calls == budget + 1 + assert result.stats["planner_reinvocations"] == budget + # Other items still attempted; Order succeeded. + assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + + +def test_retry_budget_exhausted_marks_item_fail_budget(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + attempts = engine_mod._PER_ITEM_GENERATOR_ATTEMPTS + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI) for _ in range(attempts + 2)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [ + _fail_report(hint=f"hint-{i}", bubble=False) + for i in range(attempts) + ], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "FAIL_BUDGET" + assert len(customer_log["attempts"]) == attempts + assert all(a["stage1_status"] == "FAIL" for a in customer_log["attempts"]) + assert 
planner.calls == 1 + # Order still mapped. + assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + + +def test_critic_pass_full_pipeline(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) # default PASS + critic = FakeCritic( + { + CUSTOMER_URI: [CriticResult(success=True, report=_pass_report("semantic"))], + ORDER_URI: [CriticResult(success=True, report=_pass_report("semantic"))], + HAS_ORDER_URI: [ + CriticResult(success=True, report=_pass_report("semantic")) + ], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + critic=critic, + det_eval=det, + ) + + result = _run(fake_client, skip_semantic_critic=False) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + last_attempt = customer_log["attempts"][-1] + assert last_attempt["stage1_status"] == "PASS" + assert last_attempt["critic_status"] == "PASS" + # Critic was actually called. 
+ assert CUSTOMER_URI in critic.calls + + +def test_critic_fail_with_bubble(monkeypatch, fake_client): + planner = FakePlanner( + [ + PlannerResult(success=True, source_model=_source_model(), iterations=1), + PlannerResult(success=True, source_model=_source_model(), iterations=1), + ] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + _ok_entity_gen(CUSTOMER_URI), # initial attempt — critic bubbles + _ok_entity_gen(CUSTOMER_URI), # post-replan attempt — passes + ], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) # default PASS on stage 1 + critic = FakeCritic( + { + CUSTOMER_URI: [ + CriticResult( + success=True, + report=_fail_report( + stage="semantic", hint="wrong table", bubble=True + ), + ), + CriticResult(success=True, report=_pass_report("semantic")), + ], + ORDER_URI: [CriticResult(success=True, report=_pass_report("semantic"))], + HAS_ORDER_URI: [ + CriticResult(success=True, report=_pass_report("semantic")) + ], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + critic=critic, + det_eval=det, + ) + + result = _run(fake_client, skip_semantic_critic=False) + + assert result.success is True + assert planner.calls == 2 + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PASS" + + +def test_skip_semantic_critic_true_skips_critic(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + critic = FakeCritic({}) # would default-PASS if called + 
_patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + critic=critic, + det_eval=det, + ) + + result = _run(fake_client, skip_semantic_critic=True) + + assert result.success is True + # Critic was never called. + assert critic.calls == [] + # Every attempt records critic_status="skipped". + for entry in result.mapping_run_log: + for attempt in entry["attempts"]: + assert attempt["critic_status"] == "skipped" + + +def test_preseeded_entity_skipped(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + # Customer must NOT be generated — it's pre-seeded. + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + pre = [ + _entity_mapping( + CUSTOMER_URI, + "customer_id", + "SELECT customer_id AS ID FROM cat.sch.customers", + ) + ] + + result = _run(fake_client, entity_mappings=pre) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PRESEEDED" + assert customer_log["attempts"] == [] + # The pre-seeded mapping is still in the result list. + assert any(m["ontology_class"] == CUSTOMER_URI for m in result.entity_mappings) + # EntityGenerator never called for Customer. + assert not any(c[0] == CUSTOMER_URI for c in entity_gen.calls) + + +def test_skip_list_is_advisory_item_still_mapped(monkeypatch, fake_client): + # Coverage is engine-enforced: the Planner's skip[] is ADVISORY only and + # must NOT remove an ontology class from the mapping run. Even when the + # Planner asks to skip Order, the engine still maps it. 
+ sm = _source_model() + sm.mapping_plan.skip.append(SkipItem(item=ORDER_URI, reason="planner unsure")) + sm.mapping_plan.entity_order = [CUSTOMER_URI, ORDER_URI] + + planner = FakePlanner([PlannerResult(success=True, source_model=sm, iterations=1)]) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], # MUST still be attempted + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + order_log = next(e for e in result.mapping_run_log if e["item"] == ORDER_URI) + assert order_log["final_status"] == "PASS" + assert any(c[0] == ORDER_URI for c in entity_gen.calls) + assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + + +def test_on_step_pct_monotonic(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + pcts: List[int] = [] + + def on_step(msg: str, pct: int) -> None: + pcts.append(pct) + + result = _run(fake_client, on_step=on_step) + + assert result.success is True + assert pcts, "expected at least one on_step call" + # Monotonic non-decreasing — captures the documented design contract + # (we only replan on bubble, which this test does not trigger). 
+ for prev, curr in zip(pcts, pcts[1:]): + assert curr >= prev, f"pct went backwards: {prev} -> {curr}" + # First call planning at low pct, last call completion at 100. + assert pcts[0] <= 5 + assert pcts[-1] == 100 + + +def test_id_universe_cache_used_across_relationships(monkeypatch, fake_client): + # Bare ontology with no attributes so the real deterministic evaluator + # doesn't fire on unmapped_attribute_pct. + bare_ontology = { + "entities": [ + {"uri": CUSTOMER_URI, "name": "Customer", "label": "Customer", "attributes": []}, + {"uri": ORDER_URI, "name": "Order", "label": "Order", "attributes": []}, + {"uri": ITEM_URI, "name": "Item", "label": "Item", "attributes": []}, + ], + "relationships": [ + { + "uri": HAS_ORDER_URI, + "name": "hasOrder", + "label": "hasOrder", + "domain": CUSTOMER_URI, + "range": ORDER_URI, + }, + { + "uri": CONTAINS_URI, + "name": "contains", + "label": "contains", + "domain": ORDER_URI, + "range": ITEM_URI, + }, + ], + } + + # Distinct, recognisable SQL strings per entity — used both as cache keys + # and as a discriminator for the CountingClient routing below. 
+ customer_sql = "SELECT customer_id AS ID, customer_id AS Label FROM cat.sch.customers" + order_sql = "SELECT order_id AS ID, order_id AS Label FROM cat.sch.orders" + item_sql = "SELECT item_id AS ID, item_id AS Label FROM cat.sch.items" + + planner = FakePlanner( + [ + PlannerResult( + success=True, source_model=_source_model(with_items=True), iterations=1 + ) + ] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + EntityGenResult( + success=True, + mapping=_entity_mapping(CUSTOMER_URI, "customer_id", customer_sql), + iterations=1, + ) + ], + ORDER_URI: [ + EntityGenResult( + success=True, + mapping=_entity_mapping(ORDER_URI, "order_id", order_sql), + iterations=1, + ) + ], + ITEM_URI: [ + EntityGenResult( + success=True, + mapping=_entity_mapping(ITEM_URI, "item_id", item_sql), + iterations=1, + ) + ], + } + ) + + # Relationship edges return rows whose source/target values fall inside + # the entity universes so the deterministic evaluator passes. + has_order_sql = "SELECT customer_id AS source_id, order_id AS target_id FROM has_order_edge" + contains_sql = "SELECT order_id AS source_id, item_id AS target_id FROM contains_edge" + + rel_gen = FakeRelationshipGenerator( + { + HAS_ORDER_URI: [ + RelationshipGenResult( + success=True, + mapping={ + "property": HAS_ORDER_URI, + "property_name": "hasOrder", + "sql_query": has_order_sql, + "source_id_column": "source_id", + "target_id_column": "target_id", + }, + iterations=1, + ) + ], + CONTAINS_URI: [ + RelationshipGenResult( + success=True, + mapping={ + "property": CONTAINS_URI, + "property_name": "contains", + "sql_query": contains_sql, + "source_id_column": "source_id", + "target_id_column": "target_id", + }, + iterations=1, + ) + ], + } + ) + # Use the REAL deterministic evaluators here so the cache codepath is + # actually exercised against execute_sql_fn. 
+ _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + # No det_eval override -> real evaluators used. + ) + + class CountingClient: + """Records every ``execute_query`` call so we can count cache hits.""" + + def __init__(self): + self.sql_calls: List[str] = [] + + def execute_query(self, sql: str): + self.sql_calls.append(sql) + # Entity-universe queries return rows keyed by the entity's id_column. + if sql == customer_sql: + return [{"customer_id": i, "ID": i} for i in range(1, 4)] + if sql == order_sql: + return [{"order_id": i, "ID": i} for i in range(1, 4)] + if sql == item_sql: + return [{"item_id": i, "ID": i} for i in range(1, 4)] + # Edge SQLs: values must overlap with the entity universes so + # dangling_*_pct stays low and the report PASSes. + if sql == has_order_sql: + return [ + {"source_id": i, "target_id": i, "customer_id": i, "order_id": i} + for i in range(1, 4) + ] + if sql == contains_sql: + return [ + {"source_id": i, "target_id": i, "order_id": i, "item_id": i} + for i in range(1, 4) + ] + return [] + + client = CountingClient() + result = _run(client, ontology=bare_ontology) + + assert len(result.entity_mappings) == 3, result.mapping_run_log + assert len(result.relationship_mappings) == 2, result.mapping_run_log + + # Each unique entity SQL is run by the entity evaluator (1) + at most + # ONCE more from the first relationship that references it (cached for + # subsequent relationships). Without the cache, each entity SQL would + # fire from EVERY relationship that touches it — order_sql in + # particular would run 1 (entity) + 2 (hasOrder + contains) = 3 times. 
+ for sql in (customer_sql, order_sql, item_sql): + count = sum(1 for c in client.sql_calls if c == sql) + assert count <= 2, ( + f"entity SQL ran {count} times — id_universe_cache failed:\n{sql}" + ) diff --git a/tests/agents/agent_mapping_pge/test_entity_generator.py b/tests/agents/agent_mapping_pge/test_entity_generator.py new file mode 100644 index 00000000..ed09cb03 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_entity_generator.py @@ -0,0 +1,662 @@ +"""Tests for the mapping-PGE EntityGenerator agent (Sprint 4). + +The Generator is a narrow tool-calling ReAct loop terminated by +``submit_entity_mapping``. These tests exercise the loop's control flow with +a *fake LLM* — a stub that replaces ``call_serving_endpoint`` at module +level and returns canned tool-call responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. + +What we DO exercise: +* Termination on a single submit call. +* Multi-step trajectory (execute_sql → submit). +* ``unmapped_attributes`` round-trips through the tool to the result. +* Text-only output is treated as failure (no terminal call). +* Iteration-budget exhaustion is treated as failure. +* ``retry_hint`` surfaces inside the user message. +* Step recording: every tool call produces both tool_call and tool_result + steps in the right order. 
+""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge.generators import entity as entity_mod +from agents.agent_mapping_pge.generators.entity import ( + EntityGenResult, + EntityGenStep, + run_entity_generator, +) + + +# ===================================================== +# Fake LLM scaffolding (mirrors test_planner.py) +# ===================================================== + + +_CLASS_URI = "http://ex.org/maternity#Mother" + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": [{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + # Capture the messages list as observed on each call, so tests can + # introspect what the agent put into the prompt. + self.last_messages: Optional[List[dict]] = None + self.first_messages: Optional[List[dict]] = None + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + # ``call_serving_endpoint(host, token, endpoint, messages, ...)`` — + # the messages list is positional arg #3 (zero-indexed). Capture + # defensively in case the call site changes to kwargs. 
+ msgs: Optional[List[dict]] = None + if len(args) >= 4 and isinstance(args[3], list): + msgs = args[3] + elif "messages" in kwargs: + msgs = kwargs["messages"] + if msgs is not None: + # snapshot so later mutations by the loop do not affect what we + # captured. + snapshot = [dict(m) for m in msgs] + if self.first_messages is None: + self.first_messages = snapshot + self.last_messages = snapshot + + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever.""" + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(entity_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + monkeypatch.setattr(entity_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures +# ===================================================== + + +def _ontology_class() -> dict: + return { + "uri": _CLASS_URI, + "label": "Mother", + "name": "Mother", + "comment": "A mother in the maternity trust dataset.", + "attributes": [ + {"name": "nhsNumber", "type": "string"}, + {"name": "dateOfBirth", "type": "date"}, + {"name": "ethnicity", "type": "string"}, + ], + } + + +def _source_model_slice() -> dict: + return { + "candidate_tables": [ + { + "table": "cat.sch.mothers", + "confidence": 0.92, + "reason": "row per NHS — mother demographics", + } + ], + "canonical_id": { + "canonical_column_per_table": {"cat.sch.mothers": "nhs_number"}, + "format_note": "10-digit NHS", + }, + "relevant_joins": [], + } 
+ + +def _valid_submit_args( + *, + unmapped: Optional[list] = None, +) -> dict: + args: Dict[str, Any] = { + "class_uri": _CLASS_URI, + "class_name": "Mother", + "sql_query": ( + "SELECT nhs_number AS ID, nhs_number AS Label, nhs_number, dob, ethnicity " + "FROM cat.sch.mothers WHERE nhs_number IS NOT NULL" + ), + "id_column": "nhs_number", + "label_column": "nhs_number", + "attribute_mappings": { + "nhsNumber": "nhs_number", + "dateOfBirth": "dob", + "ethnicity": "ethnicity", + }, + } + if unmapped is not None: + args["unmapped_attributes"] = unmapped + return args + + +# ===================================================== +# 1. Single-shot submit terminates immediately +# ===================================================== + + +def test_terminates_on_submit(monkeypatch, no_sleep): + """First LLM turn submits a valid mapping → success, iterations=1.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_entity_mapping", _valid_submit_args()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + assert isinstance(result, EntityGenResult) + assert result.success is True + assert result.iterations == 1 + assert result.mapping is not None + assert result.mapping["ontology_class"] == _CLASS_URI + assert result.mapping["id_column"] == "nhs_number" + assert result.error == "" + step_kinds = [s.step_type for s in result.steps] + assert step_kinds == ["tool_call", "tool_result"] + assert result.steps[0].tool_name == "submit_entity_mapping" + + +# ===================================================== +# 2. 
execute_sql validation, then submit +# ===================================================== + + +def test_validates_sql_then_submits(monkeypatch, no_sleep): + """execute_sql → submit_entity_mapping → success, iterations=2.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "execute_sql", + { + "sql": ( + "SELECT nhs_number AS ID, nhs_number AS Label, " + "nhs_number, dob, ethnicity FROM cat.sch.mothers " + "WHERE nhs_number IS NOT NULL" + ) + }, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_entity_mapping", _valid_submit_args(), tc_id="b" + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [ + { + "ID": "1234567890", + "Label": "1234567890", + "nhs_number": "1234567890", + "dob": "1990-01-01", + "ethnicity": "white", + } + ] + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.iterations == 2 + assert result.mapping is not None + # Sequence: tool_call(execute_sql), tool_result(execute_sql), + # tool_call(submit), tool_result(submit) — 4 steps. + assert len(result.steps) == 4 + assert [s.tool_name for s in result.steps] == [ + "execute_sql", + "execute_sql", + "submit_entity_mapping", + "submit_entity_mapping", + ] + + +# ===================================================== +# 3. 
unmapped_attributes round-trips +# ===================================================== + + +def test_unmapped_attributes_round_trip(monkeypatch, no_sleep): + """Submit with ``unmapped_attributes`` — the field must appear on the + resulting mapping dict in the same (normalised) shape.""" + unmapped_payload = [ + {"name": "ethnicity", "reason": "no ethnicity column in this table"} + ] + args = _valid_submit_args(unmapped=unmapped_payload) + # Strip ethnicity from attribute_mappings to make the example coherent. + args["attribute_mappings"].pop("ethnicity", None) + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("submit_entity_mapping", args)] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.mapping is not None + assert result.mapping["unmapped_attributes"] == [ + {"name": "ethnicity", "reason": "no ethnicity column in this table"} + ] + # Plain-string form is also documented; make sure it survives too. + fake2 = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_entity_mapping", + _valid_submit_args(unmapped=["ethnicity"]), + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake2) + result2 = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + assert result2.success is True + assert result2.mapping["unmapped_attributes"] == ["ethnicity"] + + +# ===================================================== +# 4. Text without terminal call → failure +# ===================================================== + + +def test_text_without_terminal_fails(monkeypatch, no_sleep): + """A plain-text response is treated as failure — the Generator must + terminate via submit_entity_mapping. 
+ """ + fake = FakeLLM( + [_llm_response(content="I am thinking…", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is False + assert result.iterations == 1 + assert result.mapping is None + assert "without submitting mapping" in result.error + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 5. Iteration-budget exhaustion → failure +# ===================================================== + + +def test_exhausts_iteration_budget(monkeypatch, no_sleep): + """Endless sample_table calls with max_iterations=3 → fail with + ``iteration budget`` and three iterations of steps recorded.""" + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="a", + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890"}] + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + max_iterations=3, + ) + + assert result.success is False + assert result.iterations == 3 + assert result.mapping is None + assert "iteration budget" in result.error + # 3 iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 6. 
retry_hint surfaces in the user prompt +# ===================================================== + + +def test_retry_hint_surfaces_in_user_prompt(monkeypatch, no_sleep): + """If ``retry_hint`` is provided, the FIRST LLM call's user message must + contain the hint verbatim.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_entity_mapping", _valid_submit_args()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + retry_hint="Use NHS column, not patient_id.", + ) + + assert result.success is True + assert fake.first_messages is not None + # messages[0] is system, messages[1] is user. + assert fake.first_messages[0]["role"] == "system" + assert fake.first_messages[1]["role"] == "user" + user_content = fake.first_messages[1]["content"] + assert "Use NHS column, not patient_id." in user_content + # RETRY HINT label is present so the LLM understands its provenance. + assert "RETRY HINT" in user_content + + +def test_system_prompt_treats_canonical_value_as_sql_expression(monkeypatch, no_sleep): + """canonical_column_per_table values may be SQL expressions (e.g. + ``regexp_extract(...)``) not just bare columns. The Generator must drop + them verbatim. 
Live V1.1 smoke surfaced 100% dangling on cross-trust + entities whose canonical IDs needed regex normalization to a common + format.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_entity_mapping", _valid_submit_args()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "SQL EXPRESSION" in system_content + assert "regexp_extract" in system_content + assert "verbatim" in system_content + assert "do NOT rewrite" in system_content + + +def test_system_prompt_mandates_union_for_cross_source(monkeypatch, no_sleep): + """When canonical_id.canonical_column_per_table lists 2+ tables, the + Generator MUST UNION across all of them. Single-trust selection on a + cross-source class makes relationship dangling 100% — the failure mode + surfaced by the live V1.1 smoke.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_entity_mapping", _valid_submit_args()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "SINGLE-SOURCE vs CROSS-SOURCE" in system_content + assert "UNION ALL" in system_content + assert "TWO OR MORE tables" in system_content + # Anti-pattern: picking one trust is called out by name. + assert "Picking just one" in system_content or "missing" in system_content + + +# ===================================================== +# 7. 
Step recording invariants +# ===================================================== + + +def test_wrong_class_uri_submission_does_not_terminate(monkeypatch, no_sleep): + """A submit_entity_mapping call with a class_uri that doesn't match the + requested one must NOT terminate the loop. The Generator must keep going + so a follow-up submit (with the correct URI) can succeed, and the LLM + must see a corrective tool message describing the mismatch. + """ + requested_uri = _CLASS_URI + other_uri = "http://ex.org/maternity#Baby" + + wrong_args = _valid_submit_args() + wrong_args["class_uri"] = other_uri + wrong_args["class_name"] = "Baby" + + fake = FakeLLM( + [ + # Turn 1: submit with the WRONG class_uri — must NOT terminate. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_entity_mapping", wrong_args, tc_id="wrong" + ) + ] + ), + # Turn 2: submit with the correct class_uri — should terminate. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_entity_mapping", + _valid_submit_args(), + tc_id="right", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.iterations == 2 + assert result.mapping is not None + assert result.mapping["ontology_class"] == requested_uri + + # The LLM's second call must have seen a corrective tool message + # describing the mismatch, surfaced through ``messages``. 
+ assert fake.last_messages is not None + tool_messages = [m for m in fake.last_messages if m.get("role") == "tool"] + assert tool_messages, "expected at least one tool message on the 2nd call" + corrective = tool_messages[-1] + corrective_content = corrective.get("content", "") + assert other_uri in corrective_content + assert requested_uri in corrective_content + assert "does not match" in corrective_content + # Sanity: the corrective payload is a JSON error (not the original + # success=True response). + parsed = json.loads(corrective_content) + assert parsed.get("success") is False + assert "error" in parsed + + +def test_records_steps(monkeypatch, no_sleep): + """Every tool-calling iteration produces exactly one ``tool_call`` step + immediately followed by one ``tool_result`` step with the same tool_name. + """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_entity_mapping", _valid_submit_args(), tc_id="b" + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890"}] + + result = run_entity_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_class=_ontology_class(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == "tool_result" + assert call_step.tool_name == result_step.tool_name + assert call_step.content != "" + assert result_step.content != "" + assert isinstance(call_step, EntityGenStep) + assert isinstance(result_step, EntityGenStep) diff --git a/tests/agents/agent_mapping_pge/test_planner.py 
b/tests/agents/agent_mapping_pge/test_planner.py new file mode 100644 index 00000000..7af091b9 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_planner.py @@ -0,0 +1,519 @@ +"""Tests for the mapping-PGE Planner agent (Sprint 3). + +The Planner is a tool-calling ReAct loop terminated by ``submit_source_model``. +These tests exercise the loop's control flow with a *fake LLM* — a stub that +replaces ``call_serving_endpoint`` at module level and returns canned tool- +call responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. The tracing decorator +is a no-op when MLflow isn't configured (see ``_TRACING_READY`` in +``agents.tracing``), so it runs cleanly here. + +What we DO exercise: +* The four termination conditions + — terminal submit_source_model with success=True breaks the loop + — text content with no tool calls is treated as failure + — iteration budget exhaustion is treated as failure + — submit returning success=False is NOT terminal (allows retry) +* Step recording: every tool call produces both tool_call and tool_result + steps in the right order. +* Iteration counter accuracy. + +What we do NOT exercise (covered elsewhere or out of scope): +* The actual content of the SourceModel — that's Sprint 1's contracts tests. +* The four planner tool handlers — that's Sprint 2's test_planner_tools.py. +* MLflow tracing semantics — the decorator is wrapped in an ``if`` guard. 
+""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge import planner as planner_mod +from agents.agent_mapping_pge.contracts import SourceModel +from agents.agent_mapping_pge.planner import ( + PlannerResult, + PlannerStep, + run_planner, +) + + +# ===================================================== +# Fake LLM scaffolding +# ===================================================== + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + """Build an OpenAI-style tool_calls entry.""" + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + """Build a minimal OpenAI-style chat-completions response.""" + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": [{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + """A stub for ``call_serving_endpoint`` that returns canned responses. + + The list is consumed front-to-back, one response per call. If a test + exhausts the list, the stub raises — that's almost always a test bug + (the loop iterated more times than the test author expected). + """ + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever. 
+ + Used for the iteration-budget-exhaustion test, where the LLM is supposed + to be stuck in an infinite loop until the engine cuts it off. + """ + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(planner_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + """Replace the planner's reference to ``call_serving_endpoint``.""" + monkeypatch.setattr(planner_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures: a minimal valid SourceModel payload +# ===================================================== + + +def _valid_source_model_dict() -> Dict[str, Any]: + """Same shape as test_planner_tools._valid_source_model_dict — kept + independent here so the two test files stay decoupled.""" + return { + "table_roles": [ + { + "table": "cat.sch.mothers", + "ontology_class_candidates": [ + { + "uri": "http://ex.org/maternity#Mother", + "confidence": 0.9, + "reason": "row per NHS", + } + ], + } + ], + "canonical_ids": [ + { + "ontology_class": "http://ex.org/maternity#Mother", + "canonical_column_per_table": {"cat.sch.mothers": "nhs_number"}, + "format_note": "", + } + ], + "join_keys": [], + "mapping_plan": { + "entity_order": ["http://ex.org/maternity#Mother"], + "relationship_order": [], + "skip": [], + }, + } + + +def _minimal_metadata() -> dict: + return { + "tables": [ + { + "name": "mothers", + "full_name": "cat.sch.mothers", + "columns": [ + {"name": "nhs_number", "type": "STRING"}, + {"name": "dob", "type": "DATE"}, + ], + } + ] + } + + +def _minimal_ontology() -> dict: + return { + "entities": [{"name": "Mother", "uri": 
"http://ex.org/maternity#Mother"}], + "relationships": [], + } + + +# ===================================================== +# 1. Single-shot submit terminates immediately +# ===================================================== + + +def test_planner_terminates_on_submit_source_model(monkeypatch, no_sleep): + """First LLM turn calls submit_source_model with a valid model — Planner + must return success=True with iterations=1 and source_model populated. + """ + sm = _valid_source_model_dict() + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("submit_source_model", {"model": sm})] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, # not used in this scenario + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert isinstance(result, PlannerResult) + assert result.success is True + assert result.iterations == 1 + assert isinstance(result.source_model, SourceModel) + assert len(result.source_model.table_roles) == 1 + assert result.error == "" + assert result.usage["prompt_tokens"] >= 0 + # Exactly one tool_call + one tool_result step. + step_kinds = [s.step_type for s in result.steps] + assert step_kinds == ["tool_call", "tool_result"] + assert result.steps[0].tool_name == "submit_source_model" + assert result.steps[1].tool_name == "submit_source_model" + + +# ===================================================== +# 2. 
Multi-step ReAct trajectory followed by submit +# ===================================================== + + +def test_planner_multi_step_then_submit(monkeypatch, no_sleep): + """get_metadata → get_ontology → sample_table → submit_source_model.""" + sm = _valid_source_model_dict() + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("get_metadata", {}, tc_id="a")] + ), + _llm_response( + tool_calls=[_make_tool_call("get_ontology", {}, tc_id="b")] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="c", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_source_model", {"model": sm}, tc_id="d" + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + # sample_table needs a client — return one row. + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890", "dob": "1990-01-01"}] + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is True + assert result.iterations == 4 + assert isinstance(result.source_model, SourceModel) + + # Every iteration produces both a tool_call and a tool_result step. + assert len(result.steps) == 8 + expected_tool_names = [ + "get_metadata", + "get_metadata", + "get_ontology", + "get_ontology", + "sample_table", + "sample_table", + "submit_source_model", + "submit_source_model", + ] + assert [s.tool_name for s in result.steps] == expected_tool_names + assert [s.step_type for s in result.steps] == [ + "tool_call", + "tool_result", + "tool_call", + "tool_result", + "tool_call", + "tool_result", + "tool_call", + "tool_result", + ] + + +# ===================================================== +# 3. 
submit returning success=False does NOT terminate +# ===================================================== + + +def test_planner_invalid_source_model_does_not_terminate(monkeypatch, no_sleep): + """First submit is malformed (missing 'table' on a table_role) — the + tool returns success=False and the Planner keeps going. Second submit + is valid and terminates the loop. + """ + bad = _valid_source_model_dict() + del bad["table_roles"][0]["table"] # break it + + good = _valid_source_model_dict() + # Pin entity_order explicitly (same value as the fixture default); it is the field asserted below. + good["mapping_plan"]["entity_order"] = ["http://ex.org/maternity#Mother"] + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_source_model", {"model": bad}, tc_id="x") + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_source_model", {"model": good}, tc_id="y" + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is True + assert result.iterations == 2 + assert isinstance(result.source_model, SourceModel) + # The valid one is what landed on ctx — pull a field from it. + assert result.source_model.mapping_plan.entity_order == [ + "http://ex.org/maternity#Mother" + ] + # Both submit attempts were recorded; the first tool_result must signal + # failure so the orchestrator can attribute the retry. + first_submit_result = result.steps[1] + assert first_submit_result.step_type == "tool_result" + assert first_submit_result.tool_name == "submit_source_model" + payload = json.loads(first_submit_result.content) + assert payload["success"] is False + + +# ===================================================== +# 4. 
Free-text output without a terminal tool call → failure +# ===================================================== + + +def test_planner_text_without_terminal_fails(monkeypatch, no_sleep): + """The Planner must terminate via submit_source_model. A plain-text + response is treated as failure. + """ + fake = FakeLLM( + [_llm_response(content="I think we are done.", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is False + assert result.iterations == 1 + assert result.source_model is None + assert "without submitting source model" in result.error + # The text was recorded as an output step for debuggability. + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 5. Iteration budget exhaustion → failure +# ===================================================== + + +def test_planner_exhausts_iteration_budget(monkeypatch, no_sleep): + """Fake LLM keeps calling get_metadata forever. With max_iterations=3 + the Planner must give up cleanly and report budget exhaustion. + """ + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("get_metadata", {}, tc_id="a")] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + max_iterations=3, + ) + + assert result.success is False + assert result.iterations == 3 + assert result.source_model is None + assert "iteration budget" in result.error + # Three iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 6. 
Step recording invariants +# ===================================================== + + +def test_planner_records_steps(monkeypatch, no_sleep): + """For each tool-calling iteration, the Planner must record exactly one + ``tool_call`` step (with non-empty arguments-as-content) and one + ``tool_result`` step (with non-empty content) — in that order, paired by + ``tool_name``. + """ + sm = _valid_source_model_dict() + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("get_metadata", {}, tc_id="a")] + ), + _llm_response( + tool_calls=[ + _make_tool_call("submit_source_model", {"model": sm}, tc_id="b") + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is True + # Verify the pairing: every even-indexed step (tool_call) is immediately + # followed by an odd-indexed step (tool_result) with the same tool_name. + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == "tool_result" + assert call_step.tool_name == result_step.tool_name + assert call_step.content != "" + assert result_step.content != "" + # PlannerStep is the right type. + assert isinstance(call_step, PlannerStep) + assert isinstance(result_step, PlannerStep) + + +# ===================================================== +# Prompt contract — canonical-key normalization guidance +# ===================================================== + + +class TestCanonicalKeyNormalizationPrompt: + """Pin the load-bearing canonical-key guidance in the system prompt. 
+ + Issue 2 root cause: the Planner left cross-trust keys disjoint (0% + overlap rationalized as "trust-scoped"), and when it did normalize it + copied a non-anchored regex that returns a leading-dash key. These + assertions keep the corrective guidance from silently regressing. + """ + + def test_offers_expression_overlap_verification_tool(self): + assert "normalized_value_overlap" in planner_mod.SYSTEM_PROMPT + + def test_zero_overlap_is_not_a_terminal_state(self): + prompt = planner_mod.SYSTEM_PROMPT + # The prompt must steer the model AWAY from accepting disjoint keys. + assert "trust-scoped" in prompt # names the trap explicitly + assert "100%" in prompt and "dangle" in prompt + + def test_regex_example_is_anchored(self): + prompt = planner_mod.SYSTEM_PROMPT + # The correct, anchored pattern must be present... + assert "[a-f0-9][a-f0-9-]+-preg-" in prompt + # ...and it must be flagged as the RIGHT one (the WRONG/RIGHT contrast + # teaches the leading-dash pitfall). + assert "✓ RIGHT" in prompt and "✗ WRONG" in prompt + + def test_derived_key_extracts_core_before_suffix(self): + prompt = planner_mod.SYSTEM_PROMPT + # Derived child keys must extract the shared core, then append suffix — + # not concat onto the raw prefixed local id. + assert "regexp_extract" in prompt + assert "-del" in prompt # the worked Delivery example diff --git a/tests/agents/agent_mapping_pge/test_relationship_generator.py b/tests/agents/agent_mapping_pge/test_relationship_generator.py new file mode 100644 index 00000000..d18e4a71 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_relationship_generator.py @@ -0,0 +1,736 @@ +"""Tests for the mapping-PGE RelationshipGenerator agent (Sprint 5). + +Mirrors the structure of ``test_entity_generator.py``. The Generator is a +narrow tool-calling ReAct loop terminated by ``submit_relationship_mapping``. 
+These tests exercise the loop's control flow with a *fake LLM* — a stub that +replaces ``call_serving_endpoint`` at module level and returns canned +responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. + +What we DO exercise: +* Termination on a single submit call. +* Multi-step trajectory (execute_sql → submit). +* Text-only output is treated as failure (no terminal call). +* Iteration-budget exhaustion is treated as failure. +* ``retry_hint`` surfaces inside the user message. +* Strict ``property_uri`` match — submit with a wrong URI is coached, not + terminal. +* Step recording invariants. +* The user prompt surfaces the source/target id_columns verbatim — pins the + Sprint 5 contract that the LLM sees the endpoint columns. +""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge.generators import relationship as rel_mod +from agents.agent_mapping_pge.generators.relationship import ( + RelationshipGenResult, + RelationshipGenStep, + run_relationship_generator, +) + + +# ===================================================== +# Fake LLM scaffolding (mirrors test_entity_generator.py) +# ===================================================== + + +_PROP_URI = "http://ex.org/maternity#motherOf" +_SOURCE_CLASS = "http://ex.org/maternity#Mother" +_TARGET_CLASS = "http://ex.org/maternity#Baby" + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": 
[{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + self.last_messages: Optional[List[dict]] = None + self.first_messages: Optional[List[dict]] = None + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + msgs: Optional[List[dict]] = None + if len(args) >= 4 and isinstance(args[3], list): + msgs = args[3] + elif "messages" in kwargs: + msgs = kwargs["messages"] + if msgs is not None: + snapshot = [dict(m) for m in msgs] + if self.first_messages is None: + self.first_messages = snapshot + self.last_messages = snapshot + + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever.""" + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(rel_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + monkeypatch.setattr(rel_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures +# ===================================================== + + +def _ontology_property() -> dict: + return { + "uri": _PROP_URI, + "label": "motherOf", + "name": "motherOf", + "comment": "Links a Mother to each of her babies.", + "domain": _SOURCE_CLASS, + "range": _TARGET_CLASS, + } + + +def _source_entity_mapping() -> dict: + return { + "ontology_class": _SOURCE_CLASS, + "class_name": 
"Mother", + "id_column": "nhs_number", + "label_column": "nhs_number", + "sql_query": ( + "SELECT nhs_number AS ID, nhs_number AS Label FROM cat.sch.mothers " + "WHERE nhs_number IS NOT NULL" + ), + } + + +def _target_entity_mapping() -> dict: + return { + "ontology_class": _TARGET_CLASS, + "class_name": "Baby", + "id_column": "baby_id", + "label_column": "baby_id", + "sql_query": ( + "SELECT baby_id AS ID, baby_id AS Label FROM cat.sch.babies " + "WHERE baby_id IS NOT NULL" + ), + } + + +def _source_model_slice() -> dict: + return { + "relevant_joins": [ + { + "from_ref": "cat.sch.babies.mother_nhs_number", + "to_ref": "cat.sch.mothers.nhs_number", + "confidence": 0.95, + "overlap_pct": 0.98, + "kind": "same_trust_fk", + } + ], + "candidate_tables": [ + {"table": "cat.sch.babies", "reason": "row per baby, has mother FK"} + ], + } + + +def _valid_submit_args() -> dict: + return { + "property_uri": _PROP_URI, + "property_name": "motherOf", + "sql_query": ( + "SELECT mother_nhs_number AS source_id, baby_id AS target_id " + "FROM cat.sch.babies WHERE mother_nhs_number IS NOT NULL" + ), + "source_id_column": "nhs_number", + "target_id_column": "baby_id", + "domain": _SOURCE_CLASS, + "range_class": _TARGET_CLASS, + "direction": "forward", + } + + +# ===================================================== +# 1. 
Single-shot submit terminates immediately +# ===================================================== + + +def test_terminates_on_submit(monkeypatch, no_sleep): + """First LLM turn submits a valid mapping → success, iterations=1.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert isinstance(result, RelationshipGenResult) + assert result.success is True + assert result.iterations == 1 + assert result.mapping is not None + assert result.mapping["property"] == _PROP_URI + assert result.mapping["source_id_column"] == "nhs_number" + assert result.mapping["target_id_column"] == "baby_id" + assert result.error == "" + step_kinds = [s.step_type for s in result.steps] + assert step_kinds == ["tool_call", "tool_result"] + assert result.steps[0].tool_name == "submit_relationship_mapping" + + +# ===================================================== +# 2. 
execute_sql validation, then submit +# ===================================================== + + +def test_validates_sql_then_submits(monkeypatch, no_sleep): + """execute_sql → submit_relationship_mapping → success, iterations=2.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "execute_sql", + { + "sql": ( + "SELECT mother_nhs_number AS source_id, baby_id " + "AS target_id FROM cat.sch.babies " + "WHERE mother_nhs_number IS NOT NULL" + ) + }, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + _valid_submit_args(), + tc_id="b", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"source_id": "1234567890", "target_id": "b-1"}] + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.iterations == 2 + assert result.mapping is not None + # Sequence: tool_call(execute_sql), tool_result(execute_sql), + # tool_call(submit), tool_result(submit) — 4 steps. + assert len(result.steps) == 4 + assert [s.tool_name for s in result.steps] == [ + "execute_sql", + "execute_sql", + "submit_relationship_mapping", + "submit_relationship_mapping", + ] + + +# ===================================================== +# 3. Text without terminal call → failure +# ===================================================== + + +def test_text_without_terminal_fails(monkeypatch, no_sleep): + """A plain-text response is treated as failure — the Generator must + terminate via submit_relationship_mapping. 
+ """ + fake = FakeLLM( + [_llm_response(content="I am thinking…", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is False + assert result.iterations == 1 + assert result.mapping is None + assert "without submitting mapping" in result.error + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 4. Iteration-budget exhaustion → failure +# ===================================================== + + +def test_exhausts_iteration_budget(monkeypatch, no_sleep): + """Endless sample_table calls with max_iterations=3 → fail with + ``iteration budget`` and three iterations of steps recorded.""" + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.babies"}, + tc_id="a", + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"mother_nhs_number": "1234567890", "baby_id": "b-1"}] + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + max_iterations=3, + ) + + assert result.success is False + assert result.iterations == 3 + assert result.mapping is None + assert "iteration budget" in result.error + # 3 iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 5. 
retry_hint surfaces in the user prompt +# ===================================================== + + +def test_retry_hint_surfaces_in_user_prompt(monkeypatch, no_sleep): + """If ``retry_hint`` is provided, the FIRST LLM call's user message must + contain the hint verbatim and the RETRY HINT label.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + retry_hint="Use mother_nhs_number, not patient_id.", + ) + + assert result.success is True + assert fake.first_messages is not None + assert fake.first_messages[0]["role"] == "system" + assert fake.first_messages[1]["role"] == "user" + user_content = fake.first_messages[1]["content"] + assert "Use mother_nhs_number, not patient_id." in user_content + assert "RETRY HINT" in user_content + # Retry-hint corrective workflow surfaces the dangling-edge probe. 
+ assert "dangling-edge probe" in user_content + assert "DO NOT repeat the same column choice" in user_content + + +def test_system_prompt_mandates_dangling_edge_self_check(monkeypatch, no_sleep): + """The system prompt must instruct the model to run a dangling-edge + probe with execute_sql BEFORE submitting — name-similarity alone is + insufficient and was the root cause of the live smoke failure on + hasapgarscore.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "SELF-VERIFY THE VALUES BEFORE SUBMITTING" in system_content + assert "dangling_src" in system_content + assert "dangling_tgt" in system_content + # The probe must reference both endpoint universes via the entity SQLs. + assert "source entity's SQL" in system_content + assert "target entity's SQL" in system_content + + +def test_system_prompt_teaches_reproducing_derived_id_expression(monkeypatch, no_sleep): + """The id_column is an alias for a derived canonical expression; the + prompt must instruct the model to REPRODUCE that expression for the + endpoints (not select a raw column) — the root cause of the 100% + source-dangling on hasapgarscore/deliveredbaby in the live smoke. 
+ """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "ALIAS FOR A DERIVED EXPRESSION" in system_content + assert "regexp_extract" in system_content + # Must steer away from selecting a raw column for the endpoint. + assert "reproduce" in system_content.lower() + + +def test_system_prompt_teaches_shared_coverage_table_rule(monkeypatch, no_sleep): + """Cross-trust endpoint entities only overlap on shared trusts; building + the edge from a table only one entity covers yields 100% dangling on the + other side. The prompt must teach picking a BOTH-covered source table. + """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "BOTH" in system_content + assert "coverage" in system_content.lower() + assert "100% source-dangling" in system_content or "100%" in system_content + + +# ===================================================== +# 6. 
Wrong property_uri submission does NOT terminate +# ===================================================== + + +def test_wrong_property_uri_submission_does_not_terminate(monkeypatch, no_sleep): + """A submit_relationship_mapping call with a property_uri that doesn't + match the requested one must NOT terminate the loop. The Generator must + keep going so a follow-up submit (with the correct URI) can succeed, and + the LLM must see a corrective tool message describing the mismatch. + """ + requested_uri = _PROP_URI + other_uri = "http://ex.org/maternity#fatherOf" + + wrong_args = _valid_submit_args() + wrong_args["property_uri"] = other_uri + wrong_args["property_name"] = "fatherOf" + + fake = FakeLLM( + [ + # Turn 1: submit with the WRONG property_uri — must NOT terminate. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + wrong_args, + tc_id="wrong", + ) + ] + ), + # Turn 2: submit with the correct property_uri — should terminate. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + _valid_submit_args(), + tc_id="right", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.iterations == 2 + assert result.mapping is not None + assert result.mapping["property"] == requested_uri + + # The LLM's second call must have seen a corrective tool message + # describing the mismatch, surfaced through ``messages``. 
+ assert fake.last_messages is not None + tool_messages = [m for m in fake.last_messages if m.get("role") == "tool"] + assert tool_messages, "expected at least one tool message on the 2nd call" + corrective = tool_messages[-1] + corrective_content = corrective.get("content", "") + assert other_uri in corrective_content + assert requested_uri in corrective_content + assert "does not match" in corrective_content + # Sanity: the corrective payload is a JSON error (not the original + # success=True response). + parsed = json.loads(corrective_content) + assert parsed.get("success") is False + assert "error" in parsed + + +# ===================================================== +# 7. Step recording invariants +# ===================================================== + + +def test_records_steps(monkeypatch, no_sleep): + """Every tool-calling iteration produces exactly one ``tool_call`` step + immediately followed by one ``tool_result`` step with the same tool_name. + """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.babies"}, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + _valid_submit_args(), + tc_id="b", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"mother_nhs_number": "1234567890", "baby_id": "b-1"}] + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == 
"tool_result" + assert call_step.tool_name == result_step.tool_name + assert call_step.content != "" + assert result_step.content != "" + assert isinstance(call_step, RelationshipGenStep) + assert isinstance(result_step, RelationshipGenStep) + + +# ===================================================== +# 8. User prompt surfaces source/target id_columns +# ===================================================== + + +def test_user_prompt_includes_source_and_target_id_columns(monkeypatch, no_sleep): + """The FIRST call's user message must contain both id_column names + verbatim. This pins the Sprint 5 contract that the Generator surfaces + the endpoint columns to the LLM, so the LLM cannot silently pick + different endpoints. + """ + # Use distinctive id_column names that won't appear anywhere else in + # the slice (mothers/babies join etc.), to make the assertion strict. + src_em = { + "ontology_class": _SOURCE_CLASS, + "class_name": "Mother", + "id_column": "weirdly_named_mother_pk", + "label_column": "weirdly_named_mother_pk", + "sql_query": "SELECT weirdly_named_mother_pk AS ID FROM cat.sch.mothers", + } + tgt_em = { + "ontology_class": _TARGET_CLASS, + "class_name": "Baby", + "id_column": "weirdly_named_baby_pk", + "label_column": "weirdly_named_baby_pk", + "sql_query": "SELECT weirdly_named_baby_pk AS ID FROM cat.sch.babies", + } + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=src_em, + target_entity_mapping=tgt_em, + source_model_slice=_source_model_slice(), + ) + + assert fake.first_messages is not None + user_content = fake.first_messages[1]["content"] + assert "weirdly_named_mother_pk" in user_content + assert "weirdly_named_baby_pk" in user_content From 
b3e11d05adae608c027b8a9dfd70ddbb7fc5fc2c Mon Sep 17 00:00:00 2001 From: Fiifi Botchway Date: Thu, 25 Jun 2026 12:45:46 +0100 Subject: [PATCH 3/4] feat(agents): add Agent Bricks Supervisor for engine selection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add agent_supervisor — a Databricks Agent Bricks Multi-Agent Supervisor that deterministically assesses a domain's complexity and routes entity/relationship mapping to the PGE engine (agent_mapping_pge) or the original simple engine (agent_auto_assignment). Hybrid routing: a deterministic UC function (assess_domain_complexity) yields the hard recommendation; the supervisor's NL instructions act on it. - complexity.py: weighted deterministic scorer (tables/columns/classes/rels + cross-source key-sharing + schema heterogeneity); reuses pge_eval.normalize. - engine.py: SupervisorEngine assess -> select -> dispatch via AgentClient. - responses_agent.py: per-engine MLflow ResponsesAgent (assess/run modes). - mas.py: SupervisorProvisioner.build_config (pure) + provision; wires the UC function + two engine endpoints with NL routing instructions. - uc_function.sql: self-contained mirror of complexity.py (parity-tested). - SPEC.md + 20-example eval dataset + scripts/provision_supervisor.py. - Tests: 35 (baseline routing 20/20, Python<->SQL parity); 759 stacked green. Stacked on the ontology-PGE and mapping-PGE PRs. 
Co-authored-by: Isaac --- .planning/agents/agent_supervisor/SPEC.md | 81 ++++++ changelogs/v0.5.2/FiifiB_2026-06-25.log | 45 ++++ scripts/provision_supervisor.py | 102 ++++++++ src/agents/agent_supervisor/__init__.py | 25 ++ src/agents/agent_supervisor/complexity.py | 244 ++++++++++++++++++ src/agents/agent_supervisor/engine.py | 187 ++++++++++++++ src/agents/agent_supervisor/log_model.py | 63 +++++ src/agents/agent_supervisor/mas.py | 174 +++++++++++++ .../agent_supervisor/responses_agent.py | 122 +++++++++ src/agents/agent_supervisor/uc_function.sql | 109 ++++++++ tests/agents/agent_supervisor/__init__.py | 0 .../agent_supervisor/test_complexity.py | 134 ++++++++++ tests/agents/agent_supervisor/test_engine.py | 165 ++++++++++++ .../datasets/agent_supervisor/baseline.jsonl | 20 ++ 14 files changed, 1471 insertions(+) create mode 100644 .planning/agents/agent_supervisor/SPEC.md create mode 100644 scripts/provision_supervisor.py create mode 100644 src/agents/agent_supervisor/__init__.py create mode 100644 src/agents/agent_supervisor/complexity.py create mode 100644 src/agents/agent_supervisor/engine.py create mode 100644 src/agents/agent_supervisor/log_model.py create mode 100644 src/agents/agent_supervisor/mas.py create mode 100644 src/agents/agent_supervisor/responses_agent.py create mode 100644 src/agents/agent_supervisor/uc_function.sql create mode 100644 tests/agents/agent_supervisor/__init__.py create mode 100644 tests/agents/agent_supervisor/test_complexity.py create mode 100644 tests/agents/agent_supervisor/test_engine.py create mode 100644 tests/eval/datasets/agent_supervisor/baseline.jsonl diff --git a/.planning/agents/agent_supervisor/SPEC.md b/.planning/agents/agent_supervisor/SPEC.md new file mode 100644 index 00000000..3ca0a482 --- /dev/null +++ b/.planning/agents/agent_supervisor/SPEC.md @@ -0,0 +1,81 @@ +# SPEC: agent_supervisor + +> Required by `.cursor/12-ai-feature-lifecycle.mdc`. + +## 1. 
Purpose + +`agent_supervisor` is a Databricks Agent Bricks Multi-Agent Supervisor (MAS) that +orchestrates OntoBricks entity/relationship mapping. It deterministically scores a +domain's complexity (from source metadata + ontology) and routes the mapping task +to either the heavyweight PGE engine (`agent_mapping_pge`) or the original simple +single-agent engine (`agent_auto_assignment`). The routing decision is computed by +a Unity Catalog function (`assess_domain_complexity`) and acted on via the +supervisor's natural-language instructions. + +## 2. Identity + +| Field | Value | +|---|---| +| `agent_name` | `agent_supervisor` | +| `module_path` | `src/agents/agent_supervisor/` | +| `model_endpoint` | Agent Bricks MAS endpoint (provisioned via `mas.py`) | +| `temperature` | `0.0` (assessment is deterministic; routing is rule-driven) | +| `mlflow_experiment` | `/Shared/ontobricks/agents/supervisor` | + +## 3. Tool surface + +| Tool name | Input | Output | Purpose | +|---|---|---|---| +| `assess_domain_complexity` (UC fn) | `metadata_json`, `ontology_json` | JSON `{score, tier, recommended_engine, signals, rationale}` | Deterministic engine recommendation | +| `pge_mapping` (endpoint) | mapping `custom_inputs` | mapping result + PGE extras | Run `agent_mapping_pge` | +| `simple_mapping` (endpoint) | mapping `custom_inputs` | mapping result | Run `agent_auto_assignment` | + +## 4. Success criteria + +1. A 3-source domain sharing an NHS-number key with ~17 classes is routed to `pge`. +2. A single-table, 2-class domain is routed to `simple`. +3. The supervisor always calls `assess_domain_complexity` before routing and never + overrides its `recommended_engine`. + +## 5. 
Eval dimensions + +| Dimension | Metric | Threshold | Weight | Judge | +|---|---|---|---|---| +| `routing_accuracy` | predicted engine == expected engine over the baseline set | `0.95` | `0.50` | rule-based (`complexity.assess`) | +| `determinism` | identical input yields identical recommendation across runs | `1.00` | `0.20` | rule-based | +| `assessor_called_first` | supervisor calls `assess_domain_complexity` before any engine | `1.00` | `0.20` | trace inspection | +| `latency_p95` | assessment seconds (excludes the engine run) | `<= 2.0` | `0.10` | wall-clock | + +**Aggregate threshold:** ≥ `0.90` to pass. + +## 6. Failure modes + +| Symptom | Detection | Mitigation | +|---|---|---| +| Supervisor skips the assessor and guesses | trace shows no `assess_domain_complexity` call | strengthen instructions; the assessor verdict is authoritative | +| Complex domain routed to simple engine | `routing_accuracy` drop on cross-source cases | re-tune weights/threshold in `complexity.py` + `uc_function.sql` (keep in sync) | +| UC function / Python drift | `test_uc_function_parity` shared-constant check | edit both files together | + +## 7. Eval dataset + +- **Baseline:** `tests/eval/datasets/agent_supervisor/baseline.jsonl` (≥20 examples; + mix of single-source/simple and multi-source/complex domains with the expected + engine). +- **Regression:** added on first production mis-route. + +## 8. MLflow tracing + +The mapping-engine ResponsesAgents (`responses_agent.py`) trace via the shared +MLflow `ResponsesAgent` plumbing; the assessment is logged at INFO. The MAS +endpoint is traced by Agent Bricks. + +## 9. Plan reference + +`docs/plans/2026-06-25-goal-loop-and-pge-eval-design.md` (PGE family) + the PR-split +plan tracked in session memory. + +## 10. Sign-off + +- [x] Sections 4, 5, 6, 7 filled. +- [ ] Baseline eval run URI pasted into PR body. +- [x] Aggregate threshold declared in §5. 
diff --git a/changelogs/v0.5.2/FiifiB_2026-06-25.log b/changelogs/v0.5.2/FiifiB_2026-06-25.log index 70b01867..f000a97b 100644 --- a/changelogs/v0.5.2/FiifiB_2026-06-25.log +++ b/changelogs/v0.5.2/FiifiB_2026-06-25.log @@ -113,3 +113,48 @@ package (12 modules) + 2 new tools + 9 test modules; 4 additive modifications - `uv run pytest tests/agents/agent_mapping_pge -q` → **90 passed**. - `uv run pytest tests/units/agents tests/units/mapping -q` → **208 passed**. - Imports resolve on the upstream base (origin/master, v0.5.2). + +# 2026-06-25 — feat(agents): Agent Bricks Supervisor for engine selection + +## Context + +PR1 (ontology PGE) and PR2 (mapping PGE) introduce the heavyweight PGE engines +alongside the retained simple engine. This change adds the orchestration layer: +a Databricks **Agent Bricks Multi-Agent Supervisor (MAS)** that, per domain, +**deterministically** assesses complexity and routes the mapping task to the PGE +engine (`agent_mapping_pge`) or the simple engine (`agent_auto_assignment`). + +Routing is the requested hybrid: a deterministic Unity Catalog function provides +the hard recommendation, and the supervisor's natural-language instructions act +on it. (Stacked on PR1 + PR2.) + +## Changes + +1. NEW `src/agents/agent_supervisor/`: + - `complexity.py` — `ComplexityAssessor`: weighted, deterministic score over + #tables, #classes, #relationships, cross-source key-sharing, and schema-naming + heterogeneity (#columns is reported only as an unweighted signal) → tier + recommended engine. Reuses + `pge_eval.normalize` for input parsing. + - `engine.py` — `SupervisorEngine`: assess → select → dispatch via + `AgentClient` (mapping has the genuine PGE-vs-simple choice; ontology uses + the single owl-generator). + - `responses_agent.py` — `MappingEngineResponsesAgent`: MLflow ResponsesAgent + serving one engine per endpoint (`assess`/`run` modes; long runs handled by + the caller as a task). 
+ - `mas.py` — `SupervisorProvisioner.build_config` (pure) + `provision`; the + MAS wires the complexity UC function + the two engine endpoints with NL + routing instructions. + - `uc_function.sql` — `assess_domain_complexity` UC function, a self-contained + mirror of `complexity.py` (constants guarded by `test_uc_function_parity`). + - `log_model.py` — logs both engine endpoints. +2. `scripts/provision_supervisor.py` — end-to-end provisioning orchestration. +3. `.planning/agents/agent_supervisor/SPEC.md` + eval dataset + `tests/eval/datasets/agent_supervisor/baseline.jsonl` (20 examples). +4. Tests: `tests/agents/agent_supervisor/{test_complexity,test_engine}.py`. + +## Tests + +- `uv run pytest tests/agents/agent_supervisor -q` → **35 passed** (incl. baseline + routing-accuracy 20/20 and Python↔SQL constant parity). +- Full stacked-branch regression `tests/agents tests/units/{agents,mapping,pge_eval,ontology}` + → **759 passed, 11 skipped**. diff --git a/scripts/provision_supervisor.py b/scripts/provision_supervisor.py new file mode 100644 index 00000000..ed9c611d --- /dev/null +++ b/scripts/provision_supervisor.py @@ -0,0 +1,102 @@ +"""Provision the OntoBricks mapping Supervisor (Agent Bricks MAS) end to end. + +Run from the repo root after PR1+PR2 land. Steps: + +1. Register the deterministic complexity UC function from ``uc_function.sql`` + (substituting ${CATALOG}/${SCHEMA}). +2. Log + deploy the two mapping-engine ResponsesAgents as Model Serving endpoints. +3. Build the Supervisor (MAS) config and create/update it via Agent Bricks. + +This script does workspace I/O and is intended to run inside a configured +Databricks environment (CLI profile or SP creds). It is deliberately thin — the +testable logic lives in ``agents.agent_supervisor.{complexity,engine,mas}``. 
+ +Usage:: + + CATALOG=fiifi_cdm_demo_catalog SCHEMA=ontobricks \\ + PGE_ENDPOINT=ob-mapping-pge SIMPLE_ENDPOINT=ob-mapping-simple \\ + python scripts/provision_supervisor.py +""" + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")) + +from agents.agent_supervisor.mas import SupervisorProvisioner # noqa: E402 +from back.core.logging import get_logger # noqa: E402 + +logger = get_logger(__name__) + + +def register_uc_function(catalog: str, schema: str, warehouse_id: str) -> None: + """Execute uc_function.sql with the catalog/schema substituted.""" + from databricks import sql as dbsql # local import: deploy-time dep + + sql_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + "agents", + "agent_supervisor", + "uc_function.sql", + ) + with open(sql_path) as fh: + ddl = fh.read().replace("${CATALOG}", catalog).replace("${SCHEMA}", schema) + + host = os.environ["DATABRICKS_HOST"].replace("https://", "") + with dbsql.connect( + server_hostname=host, + http_path=f"/sql/1.0/warehouses/{warehouse_id}", + access_token=os.environ["DATABRICKS_TOKEN"], + ) as conn: + with conn.cursor() as cur: + cur.execute(ddl) + logger.info("Registered %s.%s.assess_domain_complexity", catalog, schema) + + +def deploy_engine_endpoints(experiment: str) -> dict: + """Log + deploy both mapping-engine ResponsesAgents. 
Returns endpoint names.""" + from agents.agent_supervisor.log_model import log_engine_agent + + endpoints = {} + for engine, env_key, default in ( + ("pge", "PGE_ENDPOINT", "ob-mapping-pge"), + ("simple", "SIMPLE_ENDPOINT", "ob-mapping-simple"), + ): + uri = log_engine_agent(engine, experiment) + endpoint = os.environ.get(env_key, default) + logger.info("Logged %s engine -> %s; deploy as endpoint %r", engine, uri, endpoint) + # Deployment to Model Serving is done via databricks.agents.deploy(uri, + # endpoint) or the agents SDK; left to the operator so this script stays + # idempotent and credential-agnostic. + endpoints[engine] = endpoint + return endpoints + + +def main() -> None: + catalog = os.environ.get("CATALOG", "main") + schema = os.environ.get("SCHEMA", "ontobricks") + warehouse_id = os.environ.get("WAREHOUSE_ID", "") + experiment = os.environ.get("ONTOBRICKS_MLFLOW_EXPERIMENT", "ontobricks-agents") + + if warehouse_id: + register_uc_function(catalog, schema, warehouse_id) + else: + logger.warning("WAREHOUSE_ID unset — skipping UC function registration") + + endpoints = deploy_engine_endpoints(experiment) + + config = SupervisorProvisioner.build_config( + catalog=catalog, + schema=schema, + pge_endpoint=endpoints["pge"], + simple_endpoint=endpoints["simple"], + ) + logger.info("Supervisor config built with %d agents", len(config["agents"])) + tile_id = SupervisorProvisioner.provision(config) + logger.info("Supervisor provisioned — tile_id=%s", tile_id) + + +if __name__ == "__main__": + main() diff --git a/src/agents/agent_supervisor/__init__.py b/src/agents/agent_supervisor/__init__.py new file mode 100644 index 00000000..25616e3d --- /dev/null +++ b/src/agents/agent_supervisor/__init__.py @@ -0,0 +1,25 @@ +"""Agent Bricks Supervisor for OntoBricks mapping-engine selection. 
+ +Deterministically assesses a domain's complexity and routes the mapping task to +either the PGE engine (``agent_mapping_pge``) or the original simple engine +(``agent_auto_assignment``). Exposed to a Databricks Agent Bricks Multi-Agent +Supervisor via a complexity UC function (``uc_function.sql``) + per-engine Model +Serving endpoints (``responses_agent.py``), wired by ``mas.py``. +""" + +from agents.agent_supervisor.complexity import ( + ComplexityAssessor, + ComplexityReport, + assess, +) +from agents.agent_supervisor.engine import SupervisorEngine, SupervisorResult +from agents.agent_supervisor.mas import SupervisorProvisioner + +__all__ = [ + "ComplexityAssessor", + "ComplexityReport", + "assess", + "SupervisorEngine", + "SupervisorResult", + "SupervisorProvisioner", +] diff --git a/src/agents/agent_supervisor/complexity.py b/src/agents/agent_supervisor/complexity.py new file mode 100644 index 00000000..10c8e53b --- /dev/null +++ b/src/agents/agent_supervisor/complexity.py @@ -0,0 +1,244 @@ +"""Deterministic domain-complexity assessment for engine selection. + +The supervisor must decide, for a given domain, whether to run the heavyweight +PGE loop (planner → generator → evaluator → critic, multi-attempt, multi-replan) +or the lightweight single-agent engine. That decision is **deterministic** — it +is computed here from the source metadata + the generated ontology, never left +to an LLM's discretion — and then *exposed* to the Agent Bricks Supervisor as a +tool it calls before routing (see ``agent_supervisor/mas.py``). + +Why deterministic: a Multi-Agent Supervisor routes semantically over agent +descriptions, which is fine for "which specialist answers this question" but +unreliable for "is this domain complex enough to warrant the expensive engine". +We keep the hard threshold in code, register it as a Unity Catalog function +(``uc_function.sql``), and let the supervisor's natural-language instructions act +on its structured recommendation. 
+ +The signals that make a domain *hard to map* — and therefore worth the PGE loop: + +* **Many source tables** — more SQL surface to plan and validate. +* **Cross-source reconciliation** — the same real-world entity realised by + several tables (one per source system / region / tenant) whose keys and column + names disagree. This is exactly what the PGE planner + semantic critic exist + for; the simple engine has no notion of it. +* **A large ontology** — many classes / object properties means many mapping + items, each a chance for a dangling endpoint the evaluator must catch. +* **Schema heterogeneity** — divergent naming conventions across tables (e.g. + ``MOTHER_NHS_NO`` vs ``mother_nhs_number``) signal multi-source feeds that need + normalization the PGE engine performs. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from agents.pge_eval.normalize import ( + normalize_metadata, + normalize_name, + normalize_ontology, +) +from back.core.logging import get_logger + +logger = get_logger(__name__) + +# --- Tunable weights (sum to 1.0) and the decision threshold ---------------- +# Kept as module constants so the threshold is auditable and adjustable without +# touching logic. The UC function mirrors these values. +WEIGHT_TABLES = 0.20 +WEIGHT_CLASSES = 0.20 +WEIGHT_RELATIONSHIPS = 0.15 +WEIGHT_CROSS_SOURCE = 0.30 +WEIGHT_HETEROGENEITY = 0.15 + +# Saturation points — the count at which a signal contributes its full weight. +SATURATE_TABLES = 5 +SATURATE_CLASSES = 12 +SATURATE_RELATIONSHIPS = 10 + +# A domain scoring at/above this is routed to the PGE engine. +COMPLEXITY_THRESHOLD = 0.45 + +# An id-like column appearing in 2+ tables is the strongest cheap signal of +# cross-source reconciliation (the same entity keyed across feeds). 
+_ID_COLUMN_RE = re.compile(r"(^|_)(id|no|number|key|code|nhs|mrn|uuid)$") + + +@dataclass +class ComplexityReport: + """Structured, JSON-serialisable complexity verdict for one domain.""" + + score: float + tier: str # "simple" | "complex" + recommended_engine: str # "simple" | "pge" + signals: Dict[str, float] = field(default_factory=dict) + rationale: str = "" + + def to_dict(self) -> dict: + return { + "score": round(self.score, 4), + "tier": self.tier, + "recommended_engine": self.recommended_engine, + "signals": self.signals, + "rationale": self.rationale, + } + + +class ComplexityAssessor: + """Score a domain's mapping complexity and recommend an engine. + + Stateless; use the module-level :func:`assess` for the common path. + """ + + @staticmethod + def assess(metadata: dict, ontology: dict) -> ComplexityReport: + """Return a :class:`ComplexityReport` for *metadata* + *ontology*. + + Both arguments accept the same shapes the rest of the pipeline uses + (agent or registry shape); parsing is delegated to ``pge_eval.normalize`` + so this stays consistent with the evaluator. 
+ """ + tables = normalize_metadata(metadata or {}) + onto = normalize_ontology(ontology or {}) + + n_tables = len(tables) + n_columns = sum(len(t.get("columns") or []) for t in tables) + n_classes = len(onto.classes) + n_relationships = len(onto.object_properties) + + cross_source = ComplexityAssessor._cross_source_score(tables) + heterogeneity = ComplexityAssessor._heterogeneity_score(tables) + + s_tables = min(n_tables / SATURATE_TABLES, 1.0) + s_classes = min(n_classes / SATURATE_CLASSES, 1.0) + s_rels = min(n_relationships / SATURATE_RELATIONSHIPS, 1.0) + + score = ( + WEIGHT_TABLES * s_tables + + WEIGHT_CLASSES * s_classes + + WEIGHT_RELATIONSHIPS * s_rels + + WEIGHT_CROSS_SOURCE * cross_source + + WEIGHT_HETEROGENEITY * heterogeneity + ) + + is_complex = score >= COMPLEXITY_THRESHOLD + tier = "complex" if is_complex else "simple" + engine = "pge" if is_complex else "simple" + + signals = { + "n_tables": n_tables, + "n_columns": n_columns, + "n_classes": n_classes, + "n_relationships": n_relationships, + "cross_source": round(cross_source, 4), + "heterogeneity": round(heterogeneity, 4), + } + rationale = ComplexityAssessor._rationale(tier, signals, score) + logger.info( + "Complexity assessment — score=%.3f tier=%s engine=%s signals=%s", + score, + tier, + engine, + signals, + ) + return ComplexityReport( + score=score, + tier=tier, + recommended_engine=engine, + signals=signals, + rationale=rationale, + ) + + @staticmethod + def _cross_source_score(tables: List[dict]) -> float: + """Fraction-style [0,1] signal that the same entity spans several tables. + + Strongest evidence is an id-like column shared across multiple tables + (the join key reconciling feeds). We also treat a high table-to-shared- + key ratio as cross-source. Returns 0.0 for a single table. 
+ """ + if len(tables) < 2: + return 0.0 + + # Detect id-likeness on the raw (lowercased) name so the ``_id``/``_no`` + # suffix boundary survives, but GROUP by the normalized name so the same + # key written differently across feeds (MOTHER_NHS_NO vs mother_nhs_no) + # counts as one shared key. + id_col_tables: Dict[str, int] = {} + for t in tables: + seen_in_table = set() + for col in t.get("columns") or []: + raw = (col or "").lower() + key = normalize_name(col) + if _ID_COLUMN_RE.search(raw) and key and key not in seen_in_table: + id_col_tables[key] = id_col_tables.get(key, 0) + 1 + seen_in_table.add(key) + + shared_keys = [k for k, n in id_col_tables.items() if n >= 2] + if not shared_keys: + return 0.0 + + # How widely is the most-shared key spread across tables? + max_spread = max(id_col_tables[k] for k in shared_keys) + spread_ratio = max_spread / len(tables) + # Presence of a shared key is itself meaningful; spread scales it up. + return min(0.5 + 0.5 * spread_ratio, 1.0) + + @staticmethod + def _heterogeneity_score(tables: List[dict]) -> float: + """[0,1] signal of divergent column-naming conventions across tables. + + Mixed UPPER/lower/camel/snake conventions across feeds is a hallmark of + multi-source data needing the PGE engine's normalization. Returns the + fraction of distinct conventions observed beyond the first. + """ + if len(tables) < 2: + return 0.0 + + conventions = set() + for t in tables: + for col in t.get("columns") or []: + conventions.add(_naming_convention(col)) + conventions.discard("other") + if not conventions: + return 0.0 + # 1 convention → homogeneous (0.0); each extra convention adds signal. 
+ return min((len(conventions) - 1) / 3.0, 1.0) + + @staticmethod + def _rationale(tier: str, signals: Dict[str, float], score: float) -> str: + drivers = [] + if signals["cross_source"] > 0: + drivers.append("a shared key across multiple tables (cross-source reconciliation)") + if signals["n_tables"] >= SATURATE_TABLES: + drivers.append(f"{signals['n_tables']} source tables") + if signals["n_classes"] >= SATURATE_CLASSES: + drivers.append(f"{signals['n_classes']} ontology classes") + if signals["heterogeneity"] > 0: + drivers.append("heterogeneous column-naming across feeds") + driver_text = "; ".join(drivers) if drivers else "a small, single-source schema" + return ( + f"Score {score:.2f} ({tier}). Drivers: {driver_text}. " + f"Recommended engine: {'PGE loop' if tier == 'complex' else 'simple single-agent'}." + ) + + +def _naming_convention(name: str) -> str: + """Classify a raw column name's casing convention.""" + if not name: + return "other" + if "_" in name: + return "upper_snake" if name.isupper() else "snake" + if name.isupper(): + return "upper" + if name[0].islower() and any(c.isupper() for c in name): + return "camel" + if name.islower(): + return "lower" + return "other" + + +def assess(metadata: dict, ontology: dict) -> ComplexityReport: + """Module-level convenience wrapper over :meth:`ComplexityAssessor.assess`.""" + return ComplexityAssessor.assess(metadata, ontology) diff --git a/src/agents/agent_supervisor/engine.py b/src/agents/agent_supervisor/engine.py new file mode 100644 index 00000000..f98642dc --- /dev/null +++ b/src/agents/agent_supervisor/engine.py @@ -0,0 +1,187 @@ +"""Supervisor engine - assess complexity, then dispatch to the right engine. + +This is the in-process brain the Agent Bricks Supervisor delegates to. Given a +task ("mapping" or "ontology") plus the domain's metadata and ontology, it runs +the deterministic ``ComplexityAssessor``, picks the engine, and invokes it via +``AgentClient``. 
+ +Engine selection (deterministic): + +* "mapping" - the genuine two-engine choice. PGE: ``agent_mapping_pge``; + simple: ``agent_auto_assignment`` (the original single-agent engine from + ``master``). This is what the complexity score routes between. +* "ontology" - a single engine, ``agent_owl_generator`` (its PGE Evaluator stage + is bounded internally; there is no separate "simple ontology engine"). The + complexity report is still produced for observability, but dispatch is + unconditional. + +The mapping selection can be forced via ``engine_override`` for callers that +already know which engine they want (e.g. the supervisor acting on its own +routing decision). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Tuple + +from agents.agent_supervisor.complexity import ComplexityReport, assess +from back.core.agents.AgentClient import get_agent_client +from back.core.logging import get_logger + +logger = get_logger(__name__) + +_VALID_TASKS = ("mapping", "ontology") +_VALID_ENGINES = ("pge", "simple") + + +@dataclass +class SupervisorResult: + """Outcome of a supervised run: the routing decision + the engine result.""" + + task: str + engine_used: str + complexity: ComplexityReport + result: Any = None + success: bool = False + error: str = "" + extras: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + return { + "task": self.task, + "engine_used": self.engine_used, + "success": self.success, + "error": self.error, + "complexity": self.complexity.to_dict() if self.complexity else None, + "extras": self.extras, + } + + +class SupervisorEngine: + """Routes a domain task to the PGE or simple engine by complexity.""" + + @staticmethod + def decide_engine( + metadata: dict, + ontology: dict, + engine_override: Optional[str] = None, + ) -> Tuple[str, ComplexityReport]: + """Return ``(engine, report)``. 
+ + ``engine_override`` short-circuits the recommendation but the report is + still computed for observability. + """ + report = assess(metadata, ontology) + if engine_override in _VALID_ENGINES: + return engine_override, report + return report.recommended_engine, report + + @staticmethod + def run( + *, + task: str, + host: str, + token: str, + endpoint_name: str, + metadata: dict, + ontology: dict, + engine_override: Optional[str] = None, + client: Any = None, + entity_mappings: Any = None, + relationship_mappings: Any = None, + base_uri: str = "", + selected_tables: Optional[list] = None, + documents: Any = None, + on_step: Optional[Callable] = None, + ) -> SupervisorResult: + """Assess complexity, choose an engine, and run it. + + ``task`` is ``"mapping"`` or ``"ontology"``. Engine-specific arguments are + forwarded to the chosen engine via :class:`AgentClient`. + """ + if task not in _VALID_TASKS: + raise ValueError(f"task must be one of {_VALID_TASKS}, got {task!r}") + + engine, report = SupervisorEngine.decide_engine( + metadata, ontology, engine_override + ) + agent = get_agent_client() + + if task == "mapping": + engine_used = engine + logger.info("Supervisor routing - task=mapping engine=%s", engine) + else: + # Ontology generation has a single engine; report.recommended_engine + # is advisory only. 
+ engine_used = "owl_generator" + logger.info( + "Supervisor routing - task=ontology engine=owl_generator " + "(complexity tier=%s, advisory)", + report.tier, + ) + + try: + if task == "mapping": + result = SupervisorEngine._run_mapping( + agent, + engine, + host=host, + token=token, + endpoint_name=endpoint_name, + client=client, + metadata=metadata, + ontology=ontology, + entity_mappings=entity_mappings or [], + relationship_mappings=relationship_mappings or [], + documents=documents, + on_step=on_step, + ) + else: + result = agent.run_owl_generator( + host=host, + token=token, + endpoint_name=endpoint_name, + base_uri=base_uri, + selected_tables=selected_tables or [], + metadata=metadata, + ontology=ontology, + on_step=on_step, + ) + except Exception as exc: # surfaced to the caller; never swallowed silently + logger.error( + "Supervisor run failed (task=%s engine=%s): %s", task, engine_used, exc + ) + return SupervisorResult( + task=task, + engine_used=engine_used, + complexity=report, + success=False, + error=str(exc), + ) + + return SupervisorResult( + task=task, + engine_used=engine_used, + complexity=report, + result=result, + success=bool(getattr(result, "success", True)), + ) + + @staticmethod + def _run_mapping(agent, engine, **kw): + common = dict( + host=kw["host"], + token=kw["token"], + endpoint_name=kw["endpoint_name"], + client=kw["client"], + metadata=kw["metadata"], + ontology=kw["ontology"], + entity_mappings=kw["entity_mappings"], + relationship_mappings=kw["relationship_mappings"], + documents=kw["documents"], + on_step=kw["on_step"], + ) + if engine == "pge": + return agent.run_mapping_pge(**common) + return agent.run_auto_assignment(**common) diff --git a/src/agents/agent_supervisor/log_model.py b/src/agents/agent_supervisor/log_model.py new file mode 100644 index 00000000..ad23c13e --- /dev/null +++ b/src/agents/agent_supervisor/log_model.py @@ -0,0 +1,63 @@ +"""Log the mapping-engine ResponsesAgents to MLflow + deploy as endpoints. 
+ +Logs ``MappingEngineResponsesAgent`` twice - once per engine - so the Agent +Bricks Supervisor can route between two Model Serving endpoints. + +Usage:: + + # From the OntoBricks repository root + python -m agents.agent_supervisor.log_model + + ONTOBRICKS_MLFLOW_EXPERIMENT=my-exp python -m agents.agent_supervisor.log_model +""" + +import os +import sys + +from back.core.logging import get_logger + +logger = get_logger(__name__) + +sys.path.insert( + 0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) + +_INPUT_EXAMPLE = { + "input": [{"role": "user", "content": "Map this domain."}], + "custom_inputs": { + "mode": "assess", + "metadata": {"tables": [{"name": "patients", "columns": ["id", "name"]}]}, + "ontology": {"classes": [{"name": "Patient"}], "properties": []}, + }, +} + + +def log_engine_agent(engine: str, experiment_name: str = "ontobricks-agents") -> str: + """Log a mapping-engine ResponsesAgent for ``engine`` ('pge' | 'simple'). + + Returns the model URI ``runs:/{run_id}/mapping-{engine}``. 
+ """ + import mlflow + + if engine not in ("pge", "simple"): + raise ValueError(f"engine must be 'pge' or 'simple', got {engine!r}") + + mlflow.set_experiment(experiment_name) + artifact = f"mapping-{engine}" + with mlflow.start_run(run_name=f"supervisor-mapping-{engine}") as run: + mlflow.pyfunc.log_model( + python_model="agents/agent_supervisor/responses_agent.py", + name=artifact, + model_config={"engine": engine}, + input_example=_INPUT_EXAMPLE, + ) + model_uri = f"runs:/{run.info.run_id}/{artifact}" + logger.info("Logged %s engine agent - URI: %s", engine, model_uri) + return model_uri + + +if __name__ == "__main__": + experiment = os.getenv("ONTOBRICKS_MLFLOW_EXPERIMENT", "ontobricks-agents") + for eng in ("pge", "simple"): + uri = log_engine_agent(eng, experiment) + logger.info("Done %s -> %s", eng, uri) diff --git a/src/agents/agent_supervisor/mas.py b/src/agents/agent_supervisor/mas.py new file mode 100644 index 00000000..349bfc83 --- /dev/null +++ b/src/agents/agent_supervisor/mas.py @@ -0,0 +1,174 @@ +"""Agent Bricks Supervisor (MAS) configuration + provisioning. + +Builds the Multi-Agent Supervisor that orchestrates OntoBricks mapping. The +supervisor wires three agents: + +1. ``complexity_assessor`` - the deterministic UC function (``uc_function.sql``) + that scores a domain and recommends an engine. The supervisor calls this + FIRST. +2. ``pge_mapping`` - the Model Serving endpoint wrapping ``agent_mapping_pge``. +3. ``simple_mapping`` - the Model Serving endpoint wrapping + ``agent_auto_assignment`` (the original single-agent engine). + +Routing is the requested hybrid: a deterministic UC function provides the hard +recommendation, and natural-language instructions tell the supervisor to act on +it. The supervisor reads ``recommended_engine`` from the assessor and routes to +the matching mapping endpoint. 
+ +``build_config`` is pure (no I/O) so it can be unit-tested; ``provision`` applies +it via the Agent Bricks ``manage_mas`` MCP tool / SDK at deploy time. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List + +from back.core.logging import get_logger + +logger = get_logger(__name__) + +DEFAULT_SUPERVISOR_NAME = "OntoBricks Mapping Supervisor" + +ROUTING_INSTRUCTIONS = """\ +You orchestrate entity/relationship mapping for an OntoBricks domain. You have +three tools: + +1. complexity_assessor - a deterministic function that scores a domain's mapping + complexity from its source metadata + ontology and returns JSON with a + `recommended_engine` field ("pge" or "simple"). +2. pge_mapping - the heavyweight Planner-Generator-Evaluator mapping engine. Use + for COMPLEX domains: many source tables, cross-source reconciliation (the same + entity keyed across several feeds), large ontologies, or heterogeneous column + naming. It plans a source model and gates every mapping with a deterministic + evaluator plus a semantic critic. +3. simple_mapping - the lightweight single-agent engine. Use for SIMPLE domains: + a single source, few tables, a small ontology, and uniform schema. + +ALWAYS follow this procedure: +- Step 1: call complexity_assessor with the domain's metadata and ontology. +- Step 2: read `recommended_engine` from its JSON response. +- Step 3: if it is "pge", route the mapping task to pge_mapping; if "simple", + route to simple_mapping. Pass the domain context through unchanged. +- Never skip the assessor. Never pick an engine on your own judgement when the + assessor has given a recommendation; the assessor's verdict is authoritative. 
+""" + + +@dataclass +class SupervisorAgentRef: + """One agent entry in the MAS config.""" + + name: str + description: str + uc_function_name: str = "" + endpoint_name: str = "" + + def to_dict(self) -> dict: + d = {"name": self.name, "description": self.description} + if self.uc_function_name: + d["uc_function_name"] = self.uc_function_name + if self.endpoint_name: + d["endpoint_name"] = self.endpoint_name + return d + + +class SupervisorProvisioner: + """Build and provision the OntoBricks mapping Supervisor Agent.""" + + @staticmethod + def build_config( + *, + catalog: str, + schema: str, + pge_endpoint: str, + simple_endpoint: str, + name: str = DEFAULT_SUPERVISOR_NAME, + ) -> dict: + """Return the ``manage_mas`` create-or-update payload (pure, no I/O).""" + assessor_fn = f"{catalog}.{schema}.assess_domain_complexity" + agents: List[SupervisorAgentRef] = [ + SupervisorAgentRef( + name="complexity_assessor", + uc_function_name=assessor_fn, + description=( + "Deterministically scores a domain's mapping complexity from its " + "source metadata and ontology. Returns JSON including " + "recommended_engine ('pge' or 'simple'). CALL THIS FIRST to decide " + "which mapping engine to use." + ), + ), + SupervisorAgentRef( + name="pge_mapping", + endpoint_name=pge_endpoint, + description=( + "Heavyweight Planner-Generator-Evaluator mapping engine for COMPLEX " + "domains: many tables, cross-source reconciliation, large or " + "heterogeneous schemas. Plans a source model and gates each mapping " + "with a deterministic evaluator + semantic critic." + ), + ), + SupervisorAgentRef( + name="simple_mapping", + endpoint_name=simple_endpoint, + description=( + "Lightweight single-agent mapping engine for SIMPLE domains: a single " + "source, few tables, a small ontology, uniform schema. Fast; no " + "planning or independent evaluation." 
+ ), + ), + ] + return { + "name": name, + "description": ( + "Routes OntoBricks entity/relationship mapping to the PGE or the simple " + "engine based on a deterministic complexity assessment." + ), + "instructions": ROUTING_INSTRUCTIONS, + "agents": [a.to_dict() for a in agents], + "examples": SupervisorProvisioner._examples(), + } + + @staticmethod + def _examples() -> List[Dict[str, str]]: + return [ + { + "question": ( + "Map this domain: 3 trust feeds (trust_a, trust_b, trust_c) sharing " + "MOTHER_NHS_NO, ~17 ontology classes." + ), + "guideline": ( + "Call complexity_assessor; it returns recommended_engine='pge' " + "(cross-source + large ontology). Route to pge_mapping." + ), + }, + { + "question": ( + "Map this domain: one table 'patients' with 6 columns, 2 ontology " + "classes." + ), + "guideline": ( + "Call complexity_assessor; it returns recommended_engine='simple'. " + "Route to simple_mapping." + ), + }, + ] + + @staticmethod + def provision(config: dict) -> str: + """Create/update the Supervisor Agent from *config*. + + Uses the Agent Bricks ``manage_mas`` capability. Kept import-local and + best-effort so the module imports cleanly in environments without the + Agent Bricks SDK (e.g. unit tests); raises if provisioning is attempted + without it. + """ + try: + from databricks.agents import mas # type: ignore + except Exception as exc: # pragma: no cover - deploy-time path + raise RuntimeError( + "Agent Bricks SDK not available; provision via the manage_mas MCP " + "tool or run inside a Databricks environment." 
+ ) from exc + logger.info("Provisioning Supervisor Agent %r", config.get("name")) + return mas.create_or_update(**config) # pragma: no cover diff --git a/src/agents/agent_supervisor/responses_agent.py b/src/agents/agent_supervisor/responses_agent.py new file mode 100644 index 00000000..d6bf7735 --- /dev/null +++ b/src/agents/agent_supervisor/responses_agent.py @@ -0,0 +1,122 @@ +"""Mapping-engine MLflow ResponsesAgent wrappers for Model Serving. + +Wraps the mapping engines as ``mlflow.pyfunc.ResponsesAgent`` models so each can +be logged and served as its own Model Serving endpoint. The Agent Bricks +Supervisor then references the endpoints by name and routes between them using +the complexity assessor's recommendation (see ``mas.py`` / ``uc_function.sql``). + +Two endpoints are produced from the SAME class, parameterised by ``engine``: + +* ``engine="pge"`` -> ``agent_mapping_pge`` (planner/generator/evaluator/critic) +* ``engine="simple"`` -> ``agent_auto_assignment`` (original single-agent engine) + +The mapping run is long (minutes for a large domain). A serving endpoint must +not block indefinitely, so this wrapper supports two modes via +``custom_inputs.mode``: + +* ``"assess"`` (default when no SQL client is supplied) - run only the + deterministic complexity assessment and return the recommendation. Cheap, + always fast; lets a caller preview routing without running an engine. +* ``"run"`` - execute the wrapped engine and return the mapping result. Intended + for callers that drive it as a background task. + +The heavy lifting lives unchanged in the engine packages; this is a thin, +serving-friendly adapter. 
+""" + +import copy +from typing import Generator +from uuid import uuid4 + +import mlflow +from mlflow.models import ModelConfig +from mlflow.pyfunc import ResponsesAgent +from mlflow.types.responses import ( + ResponsesAgentRequest, + ResponsesAgentResponse, + ResponsesAgentStreamEvent, +) + +from agents.agent_supervisor.complexity import assess +from agents.agent_supervisor.engine import SupervisorEngine +from back.core.logging import get_logger + +logger = get_logger(__name__) + + +class MappingEngineResponsesAgent(ResponsesAgent): + """Serve one mapping engine (``pge`` or ``simple``) behind Model Serving. + + The engine is fixed per deployed endpoint via the model config key + ``engine`` (defaults to ``"pge"``), so the same code logs two endpoints. + """ + + def __init__(self) -> None: + try: + cfg = ModelConfig(development_config={"engine": "pge"}) + self._engine = cfg.get("engine") or "pge" + except Exception: # no config bound (e.g. unit test) -> default + self._engine = "pge" + + def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: + outputs, custom_outputs = [], {} + for event in self.predict_stream(request): + if event.type == "response.output_item.done": + outputs.append(event.item) + if getattr(event, "custom_outputs", None): + custom_outputs.update(event.custom_outputs) + return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs) + + def predict_stream( + self, request: ResponsesAgentRequest + ) -> Generator[ResponsesAgentStreamEvent, None, None]: + ci = request.custom_inputs or {} + metadata = copy.deepcopy(ci.get("metadata", {})) + ontology = copy.deepcopy(ci.get("ontology", {})) + mode = ci.get("mode") or ("run" if ci.get("client") is not None else "assess") + + report = assess(metadata, ontology) + engine = ci.get("engine_override") or self._engine + + if mode == "assess": + text = ( + f"Complexity {report.score:.2f} ({report.tier}). " + f"Recommended engine: {report.recommended_engine}. 
{report.rationale}" + ) + yield self._text_event(text, custom_outputs={"complexity": report.to_dict()}) + return + + if not ci.get("host") or not ci.get("token") or not ci.get("endpoint_name"): + yield self._text_event( + "Error: 'run' mode needs host, token, and endpoint_name in custom_inputs." + ) + return + + result = SupervisorEngine.run( + task="mapping", + host=ci["host"], + token=ci["token"], + endpoint_name=ci["endpoint_name"], + metadata=metadata, + ontology=ontology, + engine_override=engine, + client=ci.get("client"), + entity_mappings=ci.get("entity_mappings"), + relationship_mappings=ci.get("relationship_mappings"), + documents=ci.get("documents"), + ) + yield self._text_event( + f"Mapping run via '{result.engine_used}' engine - success={result.success}.", + custom_outputs=result.to_dict(), + ) + + def _text_event(self, text: str, custom_outputs: dict = None) -> ResponsesAgentStreamEvent: + return ResponsesAgentStreamEvent( + type="response.output_item.done", + item=self.create_text_output_item(text=text, id=f"msg_{uuid4().hex[:8]}"), + custom_outputs=custom_outputs or {}, + ) + + +agent = MappingEngineResponsesAgent() +mlflow.models.set_model(agent) diff --git a/src/agents/agent_supervisor/uc_function.sql b/src/agents/agent_supervisor/uc_function.sql new file mode 100644 index 00000000..d70e846b --- /dev/null +++ b/src/agents/agent_supervisor/uc_function.sql @@ -0,0 +1,109 @@ +-- ============================================================================ +-- assess_domain_complexity — deterministic engine-routing UC function +-- ============================================================================ +-- Registered as a Unity Catalog function and added to the Agent Bricks +-- Supervisor as a tool. The supervisor calls this FIRST with the domain's +-- source metadata + generated ontology, then routes the mapping task to the +-- PGE or the simple engine per the returned `recommended_engine`. 
+-- +-- This is a self-contained Python mirror of +-- `agents/agent_supervisor/complexity.py`. The weights/thresholds below MUST be +-- kept in sync with that module (the unit test `test_uc_function_parity` +-- guards the shared constants). Self-contained because UC Python functions run +-- sandboxed and cannot import the application package. +-- +-- Replace ${CATALOG} / ${SCHEMA} at deploy time (see mas.py). +-- ---------------------------------------------------------------------------- + +CREATE OR REPLACE FUNCTION ${CATALOG}.${SCHEMA}.assess_domain_complexity( + metadata_json STRING COMMENT 'Domain source metadata: {"tables":[{"name","columns":[...]}]}', + ontology_json STRING COMMENT 'Generated ontology: {"classes":[...],"properties":[...]} or agent shape' +) +RETURNS STRING +LANGUAGE PYTHON +COMMENT 'Deterministically score a domain''s mapping complexity and recommend the PGE or simple engine. Returns JSON {score, tier, recommended_engine, signals, rationale}.' +AS $$ +import json, re + +W_TABLES, W_CLASSES, W_RELS, W_CROSS, W_HET = 0.20, 0.20, 0.15, 0.30, 0.15 +SAT_TABLES, SAT_CLASSES, SAT_RELS = 5, 12, 10 +THRESHOLD = 0.45 +ID_RE = re.compile(r"(^|_)(id|no|number|key|code|nhs|mrn|uuid)$") + +def _norm(s): + return re.sub(r"[^a-z0-9]", "", (s or "").lower()) + +def _convention(name): + if not name: + return "other" + if "_" in name: + return "upper_snake" if name.isupper() else "snake" + if name.isupper(): + return "upper" + if name[0].islower() and any(c.isupper() for c in name): + return "camel" + if name.islower(): + return "lower" + return "other" + +try: + meta = json.loads(metadata_json) if metadata_json else {} +except Exception: + meta = {} +try: + onto = json.loads(ontology_json) if ontology_json else {} +except Exception: + onto = {} + +tables = meta.get("tables") or [] +def _cols(t): + return t.get("columns") or [c.get("name") if isinstance(c, dict) else c for c in (t.get("schema") or [])] + +n_tables = len(tables) +n_columns = sum(len(_cols(t) or 
[]) for t in tables) +classes = onto.get("classes") or onto.get("entities") or [] +rels = onto.get("properties") or onto.get("relationships") or [] +# object properties only, if registry shape mixes data+object properties +n_classes = len(classes) +n_rels = len([r for r in rels if (r.get("type") in (None, "object", "ObjectProperty")) or "range" in r]) if rels and isinstance(rels[0], dict) else len(rels) + +# cross-source: an id-like column shared across >=2 tables +id_spread = {} +for t in tables: + seen = set() + for c in (_cols(t) or []): + name = c if isinstance(c, str) else c.get("name", "") + raw = (name or "").lower() + key = _norm(name) + if ID_RE.search(raw) and key and key not in seen: + id_spread[key] = id_spread.get(key, 0) + 1 + seen.add(key) +shared = [k for k, v in id_spread.items() if v >= 2] +cross = 0.0 +if n_tables >= 2 and shared: + cross = min(0.5 + 0.5 * (max(id_spread[k] for k in shared) / n_tables), 1.0) + +# heterogeneity: distinct naming conventions across feeds +convs = set() +for t in tables: + for c in (_cols(t) or []): + convs.add(_convention(c if isinstance(c, str) else c.get("name", ""))) +convs.discard("other") +het = 0.0 if (n_tables < 2 or not convs) else min((len(convs) - 1) / 3.0, 1.0) + +score = (W_TABLES * min(n_tables / SAT_TABLES, 1.0) + + W_CLASSES * min(n_classes / SAT_CLASSES, 1.0) + + W_RELS * min(n_rels / SAT_RELS, 1.0) + + W_CROSS * cross + W_HET * het) +tier = "complex" if score >= THRESHOLD else "simple" +engine = "pge" if tier == "complex" else "simple" +return json.dumps({ + "score": round(score, 4), + "tier": tier, + "recommended_engine": engine, + "signals": {"n_tables": n_tables, "n_columns": n_columns, "n_classes": n_classes, + "n_relationships": n_rels, "cross_source": round(cross, 4), + "heterogeneity": round(het, 4)}, + "rationale": f"score {round(score,2)} ({tier}); recommend {engine} engine", +}) +$$; diff --git a/tests/agents/agent_supervisor/__init__.py b/tests/agents/agent_supervisor/__init__.py new file 
mode 100644 index 00000000..e69de29b diff --git a/tests/agents/agent_supervisor/test_complexity.py b/tests/agents/agent_supervisor/test_complexity.py new file mode 100644 index 00000000..d2b96d00 --- /dev/null +++ b/tests/agents/agent_supervisor/test_complexity.py @@ -0,0 +1,134 @@ +"""Unit tests for the deterministic complexity assessor + baseline routing.""" + +import json +import re +from pathlib import Path + +import pytest + +from agents.agent_supervisor.complexity import ( + COMPLEXITY_THRESHOLD, + ComplexityAssessor, + assess, +) + +pytestmark = pytest.mark.unit + +_BASELINE = ( + Path(__file__).resolve().parents[2] + / "eval" + / "datasets" + / "agent_supervisor" + / "baseline.jsonl" +) +_UC_SQL = ( + Path(__file__).resolve().parents[3] + / "src" + / "agents" + / "agent_supervisor" + / "uc_function.sql" +) + + +def test_single_table_small_ontology_is_simple(): + report = assess( + {"tables": [{"name": "patients", "columns": ["id", "name", "dob"]}]}, + {"classes": [{"name": "Patient", "attributes": ["name", "dob"]}], "properties": []}, + ) + assert report.tier == "simple" + assert report.recommended_engine == "simple" + assert report.score < COMPLEXITY_THRESHOLD + + +def test_three_sources_sharing_key_is_complex(): + md = { + "tables": [ + {"name": "trust_a", "columns": ["EPISODE_ID", "MOTHER_NHS_NO", "DELIVERY_DATE"]}, + {"name": "trust_b", "columns": ["pregnancy_id", "mother_nhs_no", "booking_date"]}, + {"name": "trust_c", "columns": ["event_id", "mother_nhs_number", "event_date"]}, + ] + } + onto = { + "classes": [{"name": n} for n in ("Mother", "Baby", "Pregnancy", "Delivery", "Labour")], + "properties": [{"name": "hasBaby", "domain": "Mother", "range": "Baby"}], + } + report = assess(md, onto) + assert report.tier == "complex" + assert report.recommended_engine == "pge" + assert report.signals["cross_source"] > 0 + + +def test_cross_source_zero_for_single_table(): + report = assess( + {"tables": [{"name": "t", "columns": ["id", "x"]}]}, + {"classes": 
[{"name": "T"}], "properties": []}, + ) + assert report.signals["cross_source"] == 0.0 + + +def test_assessment_is_deterministic(): + md = {"tables": [{"name": "a", "columns": ["id", "v"]}, {"name": "b", "columns": ["id", "w"]}]} + onto = {"classes": [{"name": "A"}, {"name": "B"}], "properties": []} + first = assess(md, onto).to_dict() + for _ in range(5): + assert assess(md, onto).to_dict() == first + + +def test_empty_inputs_do_not_crash(): + report = assess({}, {}) + assert report.tier == "simple" + assert report.recommended_engine == "simple" + + +def test_report_is_json_serialisable(): + report = assess({"tables": [{"name": "t", "columns": ["id"]}]}, {"classes": [], "properties": []}) + json.dumps(report.to_dict()) # must not raise + + +def _load_baseline(): + rows = [] + with _BASELINE.open() as fh: + for line in fh: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def test_baseline_dataset_has_min_examples(): + rows = _load_baseline() + assert len(rows) >= 20, "eval-gate requires >= 20 examples for a new agent" + + +@pytest.mark.parametrize("row", _load_baseline(), ids=lambda r: r["id"]) +def test_baseline_routing_accuracy(row): + """Every baseline case must route to its expected engine (accuracy == 1.0).""" + report = assess(row["input"]["metadata"], row["input"]["ontology"]) + assert report.recommended_engine == row["expected"]["recommended_engine"], ( + f"{row['id']}: got {report.recommended_engine} " + f"(score={report.score:.3f}, signals={report.signals})" + ) + + +def test_uc_function_parity(): + """The UC function must embed the same weights/threshold as complexity.py.""" + from agents.agent_supervisor import complexity as c + + sql = _UC_SQL.read_text() + # Constants line: W_TABLES, W_CLASSES, W_RELS, W_CROSS, W_HET = ... 
+ weights = re.search( + r"W_TABLES,\s*W_CLASSES,\s*W_RELS,\s*W_CROSS,\s*W_HET\s*=\s*" + r"([0-9.]+),\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)", + sql, + ) + assert weights, "weights constant not found in uc_function.sql" + got = [float(x) for x in weights.groups()] + assert got == [ + c.WEIGHT_TABLES, + c.WEIGHT_CLASSES, + c.WEIGHT_RELATIONSHIPS, + c.WEIGHT_CROSS_SOURCE, + c.WEIGHT_HETEROGENEITY, + ] + threshold = re.search(r"THRESHOLD\s*=\s*([0-9.]+)", sql) + assert threshold and float(threshold.group(1)) == c.COMPLEXITY_THRESHOLD diff --git a/tests/agents/agent_supervisor/test_engine.py b/tests/agents/agent_supervisor/test_engine.py new file mode 100644 index 00000000..7081802b --- /dev/null +++ b/tests/agents/agent_supervisor/test_engine.py @@ -0,0 +1,165 @@ +"""Unit tests for the supervisor engine selection + dispatch.""" + +import pytest + +from agents.agent_supervisor import mas as mas_mod +from agents.agent_supervisor.engine import SupervisorEngine + +pytestmark = pytest.mark.unit + + +class _FakeResult: + def __init__(self): + self.success = True + + +class _FakeAgentClient: + def __init__(self): + self.calls = [] + + def run_mapping_pge(self, **kw): + self.calls.append(("pge", kw)) + return _FakeResult() + + def run_auto_assignment(self, **kw): + self.calls.append(("simple", kw)) + return _FakeResult() + + def run_owl_generator(self, **kw): + self.calls.append(("owl", kw)) + return _FakeResult() + + +_SIMPLE_MD = {"tables": [{"name": "t", "columns": ["id", "x"]}]} +_SIMPLE_ONTO = {"classes": [{"name": "T"}], "properties": []} +_COMPLEX_MD = { + "tables": [ + {"name": "a", "columns": ["entity_id", "p"]}, + {"name": "b", "columns": ["entity_id", "q"]}, + {"name": "c", "columns": ["ENTITY_ID", "r"]}, + ] +} +_COMPLEX_ONTO = { + "classes": [{"name": n} for n in ("A", "B", "C", "D", "E", "F", "G", "H")], + "properties": [{"name": "rel", "domain": "A", "range": "B"}], +} + + +@pytest.fixture +def fake_client(monkeypatch): + client = _FakeAgentClient() 
+ monkeypatch.setattr( + "agents.agent_supervisor.engine.get_agent_client", lambda: client + ) + return client + + +def test_invalid_task_raises(fake_client): + with pytest.raises(ValueError): + SupervisorEngine.run( + task="nonsense", + host="h", + token="t", + endpoint_name="e", + metadata=_SIMPLE_MD, + ontology=_SIMPLE_ONTO, + ) + + +def test_simple_domain_routes_to_simple_engine(fake_client): + res = SupervisorEngine.run( + task="mapping", + host="h", + token="t", + endpoint_name="e", + metadata=_SIMPLE_MD, + ontology=_SIMPLE_ONTO, + client=object(), + ) + assert res.engine_used == "simple" + assert fake_client.calls[0][0] == "simple" + assert res.success + + +def test_complex_domain_routes_to_pge_engine(fake_client): + res = SupervisorEngine.run( + task="mapping", + host="h", + token="t", + endpoint_name="e", + metadata=_COMPLEX_MD, + ontology=_COMPLEX_ONTO, + client=object(), + ) + assert res.engine_used == "pge" + assert fake_client.calls[0][0] == "pge" + + +def test_engine_override_forces_engine(fake_client): + res = SupervisorEngine.run( + task="mapping", + host="h", + token="t", + endpoint_name="e", + metadata=_COMPLEX_MD, + ontology=_COMPLEX_ONTO, + engine_override="simple", + client=object(), + ) + assert res.engine_used == "simple" + assert fake_client.calls[0][0] == "simple" + # complexity report is still computed for observability + assert res.complexity.tier == "complex" + + +def test_ontology_task_uses_single_engine(fake_client): + res = SupervisorEngine.run( + task="ontology", + host="h", + token="t", + endpoint_name="e", + metadata=_COMPLEX_MD, + ontology=_COMPLEX_ONTO, + base_uri="http://x#", + selected_tables=["a", "b"], + ) + assert res.engine_used == "owl_generator" + assert fake_client.calls[0][0] == "owl" + + +def test_engine_failure_is_surfaced(monkeypatch): + class _Boom(_FakeAgentClient): + def run_auto_assignment(self, **kw): + raise RuntimeError("engine exploded") + + monkeypatch.setattr( + 
"agents.agent_supervisor.engine.get_agent_client", lambda: _Boom() + ) + res = SupervisorEngine.run( + task="mapping", + host="h", + token="t", + endpoint_name="e", + metadata=_SIMPLE_MD, + ontology=_SIMPLE_ONTO, + client=object(), + ) + assert res.success is False + assert "engine exploded" in res.error + + +def test_mas_config_shape(): + cfg = mas_mod.SupervisorProvisioner.build_config( + catalog="cat", + schema="sch", + pge_endpoint="ob-mapping-pge", + simple_endpoint="ob-mapping-simple", + ) + names = [a["name"] for a in cfg["agents"]] + assert names == ["complexity_assessor", "pge_mapping", "simple_mapping"] + assessor = cfg["agents"][0] + assert assessor["uc_function_name"] == "cat.sch.assess_domain_complexity" + assert cfg["agents"][1]["endpoint_name"] == "ob-mapping-pge" + assert cfg["agents"][2]["endpoint_name"] == "ob-mapping-simple" + assert "complexity_assessor" in cfg["instructions"] + assert len(cfg["examples"]) >= 2 diff --git a/tests/eval/datasets/agent_supervisor/baseline.jsonl b/tests/eval/datasets/agent_supervisor/baseline.jsonl new file mode 100644 index 00000000..8f755936 --- /dev/null +++ b/tests/eval/datasets/agent_supervisor/baseline.jsonl @@ -0,0 +1,20 @@ +{"id": "simple-single-table-001", "input": {"metadata": {"tables": [{"name": "patients", "columns": ["id", "name", "dob"]}]}, "ontology": {"classes": [{"name": "Patient", "attributes": ["name", "dob"]}], "properties": []}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple", "single-source"]} +{"id": "simple-single-table-002", "input": {"metadata": {"tables": [{"name": "orders", "columns": ["order_id", "amount", "status"]}]}, "ontology": {"classes": [{"name": "Order", "attributes": ["amount", "status"]}], "properties": []}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple", "single-source"]} +{"id": "simple-two-table-no-shared-003", "input": {"metadata": {"tables": [{"name": "products", "columns": ["product_id", "title"]}, {"name": "categories", "columns": 
["category_id", "label"]}]}, "ontology": {"classes": [{"name": "Product", "attributes": ["title"]}, {"name": "Category", "attributes": ["label"]}], "properties": [{"name": "belongsTo", "domain": "Product", "range": "Category"}]}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-small-onto-004", "input": {"metadata": {"tables": [{"name": "employees", "columns": ["emp_id", "first_name", "last_name", "title"]}]}, "ontology": {"classes": [{"name": "Employee", "attributes": ["first_name", "last_name", "title"]}], "properties": []}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-single-source-005", "input": {"metadata": {"tables": [{"name": "invoices", "columns": ["invoice_id", "total", "issued_date"]}]}, "ontology": {"classes": [{"name": "Invoice", "attributes": ["total", "issued_date"]}], "properties": []}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-two-class-006", "input": {"metadata": {"tables": [{"name": "books", "columns": ["book_id", "title", "author_name"]}]}, "ontology": {"classes": [{"name": "Book", "attributes": ["title"]}, {"name": "Author", "attributes": ["author_name"]}], "properties": [{"name": "writtenBy", "domain": "Book", "range": "Author"}]}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-tiny-007", "input": {"metadata": {"tables": [{"name": "devices", "columns": ["device_id", "model"]}]}, "ontology": {"classes": [{"name": "Device", "attributes": ["model"]}], "properties": []}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-homogeneous-008", "input": {"metadata": {"tables": [{"name": "accounts", "columns": ["account_id", "balance", "opened_date"]}]}, "ontology": {"classes": [{"name": "Account", "attributes": ["balance", "opened_date"]}], "properties": []}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-two-tables-distinct-keys-009", 
"input": {"metadata": {"tables": [{"name": "stores", "columns": ["store_id", "city"]}, {"name": "regions", "columns": ["region_id", "name"]}]}, "ontology": {"classes": [{"name": "Store", "attributes": ["city"]}, {"name": "Region", "attributes": ["name"]}], "properties": [{"name": "inRegion", "domain": "Store", "range": "Region"}]}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "simple-modest-010", "input": {"metadata": {"tables": [{"name": "tickets", "columns": ["ticket_id", "subject", "priority"]}]}, "ontology": {"classes": [{"name": "Ticket", "attributes": ["subject", "priority"]}, {"name": "Agent", "attributes": ["agent_name"]}], "properties": [{"name": "assignedTo", "domain": "Ticket", "range": "Agent"}]}}, "expected": {"recommended_engine": "simple"}, "tags": ["simple"]} +{"id": "complex-three-trust-shared-nhs-011", "input": {"metadata": {"tables": [{"name": "trust_a_episode", "columns": ["EPISODE_ID", "MOTHER_NHS_NO", "BABY_NHS_NO", "DELIVERY_DATE", "MODE_OF_DELIVERY"]}, {"name": "trust_b_pregnancy", "columns": ["pregnancy_id", "mother_nhs_no", "booking_date", "gestation_weeks"]}, {"name": "trust_c_event", "columns": ["event_id", "mother_nhs_number", "event_type", "event_date"]}]}, "ontology": {"classes": [{"name": "Mother"}, {"name": "Baby"}, {"name": "Pregnancy"}, {"name": "Delivery"}, {"name": "AntenatalContact"}, {"name": "Labour"}, {"name": "Postnatal"}, {"name": "ClinicalFinding"}], "properties": [{"name": "hasBaby", "domain": "Mother", "range": "Baby"}, {"name": "hasPregnancy", "domain": "Mother", "range": "Pregnancy"}, {"name": "hasDelivery", "domain": "Pregnancy", "range": "Delivery"}, {"name": "hasContact", "domain": "Pregnancy", "range": "AntenatalContact"}, {"name": "hasLabour", "domain": "Pregnancy", "range": "Labour"}, {"name": "hasFinding", "domain": "Pregnancy", "range": "ClinicalFinding"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source", "heterogeneous"]} +{"id": 
"complex-multi-source-shared-key-012", "input": {"metadata": {"tables": [{"name": "crm_customer", "columns": ["customer_id", "full_name", "email"]}, {"name": "erp_customer", "columns": ["CUSTOMER_ID", "NAME", "TAX_CODE"]}, {"name": "billing_customer", "columns": ["customer_id", "billing_email", "credit_limit"]}]}, "ontology": {"classes": [{"name": "Customer"}, {"name": "Invoice"}, {"name": "Address"}, {"name": "Account"}], "properties": [{"name": "hasInvoice", "domain": "Customer", "range": "Invoice"}, {"name": "hasAddress", "domain": "Customer", "range": "Address"}, {"name": "hasAccount", "domain": "Customer", "range": "Account"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source", "heterogeneous"]} +{"id": "complex-large-ontology-013", "input": {"metadata": {"tables": [{"name": "claims", "columns": ["claim_id", "member_id", "provider_id", "amount", "diagnosis_code"]}, {"name": "members", "columns": ["member_id", "name", "plan_id"]}, {"name": "providers", "columns": ["provider_id", "provider_name", "specialty"]}]}, "ontology": {"classes": [{"name": "Claim"}, {"name": "Member"}, {"name": "Provider"}, {"name": "Plan"}, {"name": "Diagnosis"}, {"name": "Procedure"}, {"name": "Encounter"}, {"name": "Payment"}, {"name": "Pharmacy"}, {"name": "Prescription"}, {"name": "Facility"}, {"name": "Coverage"}], "properties": [{"name": "filedBy", "domain": "Claim", "range": "Member"}, {"name": "treatedBy", "domain": "Claim", "range": "Provider"}, {"name": "hasDiagnosis", "domain": "Claim", "range": "Diagnosis"}, {"name": "underPlan", "domain": "Member", "range": "Plan"}, {"name": "atFacility", "domain": "Encounter", "range": "Facility"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source", "large-ontology"]} +{"id": "complex-heterogeneous-naming-014", "input": {"metadata": {"tables": [{"name": "src_legacy", "columns": ["CUST_ID", "FIRST_NM", "LAST_NM"]}, {"name": "src_modern", "columns": ["customerId", "firstName", 
"lastName"]}, {"name": "src_warehouse", "columns": ["cust_id", "first_name", "last_name"]}]}, "ontology": {"classes": [{"name": "Customer"}, {"name": "Order"}, {"name": "LineItem"}, {"name": "Product"}, {"name": "Shipment"}], "properties": [{"name": "places", "domain": "Customer", "range": "Order"}, {"name": "contains", "domain": "Order", "range": "LineItem"}, {"name": "refersTo", "domain": "LineItem", "range": "Product"}, {"name": "shippedVia", "domain": "Order", "range": "Shipment"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "heterogeneous", "cross-source"]} +{"id": "complex-five-tables-015", "input": {"metadata": {"tables": [{"name": "t1", "columns": ["entity_id", "a"]}, {"name": "t2", "columns": ["entity_id", "b"]}, {"name": "t3", "columns": ["entity_id", "c"]}, {"name": "t4", "columns": ["entity_id", "d"]}, {"name": "t5", "columns": ["entity_id", "e"]}]}, "ontology": {"classes": [{"name": "Entity"}, {"name": "Event"}, {"name": "Attribute"}], "properties": [{"name": "hasEvent", "domain": "Entity", "range": "Event"}, {"name": "hasAttribute", "domain": "Entity", "range": "Attribute"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source"]} +{"id": "complex-patient-multi-source-016", "input": {"metadata": {"tables": [{"name": "ehr_patient", "columns": ["patient_id", "mrn", "dob"]}, {"name": "lab_results", "columns": ["patient_id", "test_code", "value"]}, {"name": "pharmacy_dispense", "columns": ["patient_id", "drug_code", "dispense_date"]}, {"name": "gp_record", "columns": ["PATIENT_ID", "NHS_NUMBER", "REGISTERED_DATE"]}]}, "ontology": {"classes": [{"name": "Patient"}, {"name": "LabResult"}, {"name": "Dispense"}, {"name": "Drug"}, {"name": "GPRegistration"}, {"name": "Observation"}, {"name": "Condition"}, {"name": "Encounter"}], "properties": [{"name": "hasLabResult", "domain": "Patient", "range": "LabResult"}, {"name": "hasDispense", "domain": "Patient", "range": "Dispense"}, {"name": "dispensed", "domain": 
"Dispense", "range": "Drug"}, {"name": "registeredAt", "domain": "Patient", "range": "GPRegistration"}, {"name": "hasObservation", "domain": "Patient", "range": "Observation"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source", "large-ontology", "heterogeneous"]} +{"id": "complex-shared-code-key-017", "input": {"metadata": {"tables": [{"name": "sales_us", "columns": ["product_code", "region", "units"]}, {"name": "sales_eu", "columns": ["product_code", "country", "units"]}, {"name": "catalog", "columns": ["product_code", "name", "category"]}]}, "ontology": {"classes": [{"name": "Product"}, {"name": "Sale"}, {"name": "Region"}, {"name": "Category"}, {"name": "Channel"}, {"name": "Customer"}], "properties": [{"name": "soldAs", "domain": "Product", "range": "Sale"}, {"name": "inRegion", "domain": "Sale", "range": "Region"}, {"name": "ofCategory", "domain": "Product", "range": "Category"}, {"name": "viaChannel", "domain": "Sale", "range": "Channel"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source"]} +{"id": "complex-many-rels-018", "input": {"metadata": {"tables": [{"name": "fact_a", "columns": ["record_id", "dim1_id", "dim2_id"]}, {"name": "fact_b", "columns": ["record_id", "dim3_id", "measure"]}, {"name": "dims", "columns": ["RECORD_ID", "label"]}]}, "ontology": {"classes": [{"name": "Record"}, {"name": "Dim1"}, {"name": "Dim2"}, {"name": "Dim3"}, {"name": "Measure"}, {"name": "Fact"}, {"name": "Time"}, {"name": "Geography"}, {"name": "Product"}, {"name": "Channel"}], "properties": [{"name": "hasDim1", "domain": "Record", "range": "Dim1"}, {"name": "hasDim2", "domain": "Record", "range": "Dim2"}, {"name": "hasDim3", "domain": "Record", "range": "Dim3"}, {"name": "hasMeasure", "domain": "Fact", "range": "Measure"}, {"name": "atTime", "domain": "Fact", "range": "Time"}, {"name": "atGeo", "domain": "Fact", "range": "Geography"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", 
"cross-source", "large-ontology"]} +{"id": "complex-finance-multi-019", "input": {"metadata": {"tables": [{"name": "trades_fix", "columns": ["TRADE_ID", "ACCOUNT_NO", "SYMBOL", "QTY"]}, {"name": "positions", "columns": ["account_no", "symbol", "quantity"]}, {"name": "accounts", "columns": ["account_no", "holder_name", "open_date"]}]}, "ontology": {"classes": [{"name": "Trade"}, {"name": "Account"}, {"name": "Position"}, {"name": "Instrument"}, {"name": "Holder"}, {"name": "Settlement"}], "properties": [{"name": "onAccount", "domain": "Trade", "range": "Account"}, {"name": "ofInstrument", "domain": "Trade", "range": "Instrument"}, {"name": "heldBy", "domain": "Account", "range": "Holder"}, {"name": "hasPosition", "domain": "Account", "range": "Position"}, {"name": "settledBy", "domain": "Trade", "range": "Settlement"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source", "heterogeneous"]} +{"id": "complex-iot-multi-source-020", "input": {"metadata": {"tables": [{"name": "sensor_stream", "columns": ["device_id", "ts", "reading"]}, {"name": "device_registry", "columns": ["DEVICE_ID", "model", "install_date"]}, {"name": "maintenance_log", "columns": ["device_id", "tech_id", "service_date"]}]}, "ontology": {"classes": [{"name": "Device"}, {"name": "Reading"}, {"name": "MaintenanceEvent"}, {"name": "Technician"}, {"name": "Site"}, {"name": "Alert"}, {"name": "Firmware"}], "properties": [{"name": "emits", "domain": "Device", "range": "Reading"}, {"name": "serviced", "domain": "Device", "range": "MaintenanceEvent"}, {"name": "performedBy", "domain": "MaintenanceEvent", "range": "Technician"}, {"name": "installedAt", "domain": "Device", "range": "Site"}, {"name": "raises", "domain": "Device", "range": "Alert"}]}}, "expected": {"recommended_engine": "pge"}, "tags": ["complex", "cross-source", "heterogeneous"]} From 795622f7653274881f6b5d3e9026606c81005f38 Mon Sep 17 00:00:00 2001 From: Fiifi Botchway Date: Thu, 25 Jun 2026 17:37:28 +0100 Subject: 
[PATCH 4/4] refactor(agents): simplify supervisor engine for reviewability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Behavior-preserving post-review pass on the new supervisor code: - engine.py: Remove Middle Man — drop _run_mapping(**kw) pack/unpack indirection; select run_mapping_pge vs run_auto_assignment inline. - responses_agent.py: compute assess() only in the branch that uses it; Optional[dict] type hint. Tests: tests/agents/agent_supervisor 35 passed (unchanged). Co-authored-by: Isaac --- changelogs/v0.5.2/FiifiB_2026-06-25.log | 21 +++++++++++++++ src/agents/agent_supervisor/engine.py | 27 +++++-------------- .../agent_supervisor/responses_agent.py | 12 ++++----- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/changelogs/v0.5.2/FiifiB_2026-06-25.log b/changelogs/v0.5.2/FiifiB_2026-06-25.log index f000a97b..8c73edc6 100644 --- a/changelogs/v0.5.2/FiifiB_2026-06-25.log +++ b/changelogs/v0.5.2/FiifiB_2026-06-25.log @@ -158,3 +158,24 @@ on it. (Stacked on PR1 + PR2.) routing-accuracy 20/20 and Python↔SQL constant parity). - Full stacked-branch regression `tests/agents tests/units/{agents,mapping,pge_eval,ontology}` → **759 passed, 11 skipped**. + +# 2026-06-25 — refactor(agents): simplify supervisor engine for reviewability + +## Context + +Post-review simplification pass on the new supervisor code (behavior-preserving). + +## Changes + +1. `src/agents/agent_supervisor/engine.py` — Remove Middle Man: deleted the + `_run_mapping(**kw)` indirection that packed→unpacked→repacked identical + kwargs; `run()` now selects `run_mapping_pge` vs `run_auto_assignment` inline + (−15 lines), keeping the dispatch decision beside its call. +2. `src/agents/agent_supervisor/responses_agent.py` — moved the `assess()` call + into the `assess` branch that consumes it (the `run` path recomputes it in + `SupervisorEngine.run`), inlined a one-use local, tightened `_text_event` + `custom_outputs` to `Optional[dict]`. 
+ +## Tests + +`uv run pytest tests/agents/agent_supervisor -q` → **35 passed** (unchanged). diff --git a/src/agents/agent_supervisor/engine.py b/src/agents/agent_supervisor/engine.py index f98642dc..7d55ee82 100644 --- a/src/agents/agent_supervisor/engine.py +++ b/src/agents/agent_supervisor/engine.py @@ -123,9 +123,12 @@ def run( try: if task == "mapping": - result = SupervisorEngine._run_mapping( - agent, - engine, + run_engine = ( + agent.run_mapping_pge + if engine == "pge" + else agent.run_auto_assignment + ) + result = run_engine( host=host, token=token, endpoint_name=endpoint_name, @@ -167,21 +170,3 @@ def run( result=result, success=bool(getattr(result, "success", True)), ) - - @staticmethod - def _run_mapping(agent, engine, **kw): - common = dict( - host=kw["host"], - token=kw["token"], - endpoint_name=kw["endpoint_name"], - client=kw["client"], - metadata=kw["metadata"], - ontology=kw["ontology"], - entity_mappings=kw["entity_mappings"], - relationship_mappings=kw["relationship_mappings"], - documents=kw["documents"], - on_step=kw["on_step"], - ) - if engine == "pge": - return agent.run_mapping_pge(**common) - return agent.run_auto_assignment(**common) diff --git a/src/agents/agent_supervisor/responses_agent.py b/src/agents/agent_supervisor/responses_agent.py index d6bf7735..268430f5 100644 --- a/src/agents/agent_supervisor/responses_agent.py +++ b/src/agents/agent_supervisor/responses_agent.py @@ -25,7 +25,7 @@ """ import copy -from typing import Generator +from typing import Generator, Optional from uuid import uuid4 import mlflow @@ -75,10 +75,8 @@ def predict_stream( ontology = copy.deepcopy(ci.get("ontology", {})) mode = ci.get("mode") or ("run" if ci.get("client") is not None else "assess") - report = assess(metadata, ontology) - engine = ci.get("engine_override") or self._engine - if mode == "assess": + report = assess(metadata, ontology) text = ( f"Complexity {report.score:.2f} ({report.tier}). " f"Recommended engine: {report.recommended_engine}. 
{report.rationale}" @@ -99,7 +97,7 @@ def predict_stream( endpoint_name=ci["endpoint_name"], metadata=metadata, ontology=ontology, - engine_override=engine, + engine_override=ci.get("engine_override") or self._engine, client=ci.get("client"), entity_mappings=ci.get("entity_mappings"), relationship_mappings=ci.get("relationship_mappings"), @@ -110,7 +108,9 @@ def predict_stream( custom_outputs=result.to_dict(), ) - def _text_event(self, text: str, custom_outputs: dict = None) -> ResponsesAgentStreamEvent: + def _text_event( + self, text: str, custom_outputs: Optional[dict] = None + ) -> ResponsesAgentStreamEvent: return ResponsesAgentStreamEvent( type="response.output_item.done", item=self.create_text_output_item(text=text, id=f"msg_{uuid4().hex[:8]}"),