From 7895c3a569ae5dbf4759134912556b85097ff37c Mon Sep 17 00:00:00 2001
From: Hussain Sultan
Date: Sun, 3 May 2026 10:39:40 -0400
Subject: [PATCH] refactor: introduce Predicate AST for filter expressions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Three places in ``query.py`` walked filter dicts. Two of them built ibis
expressions: ``Filter._parse_json_filter`` (pre-agg, attribute access)
and ``_build_post_agg_predicate`` (post-agg, bracket access). The third,
``_extract_filter_fields``, collected the referenced fields. They had
diverged in subtle ways (e.g. how AND/OR chain) and there was no shared
representation that could be inspected, optimized, or serialized.

Introduce a small AST in ``boring_semantic_layer.predicate``:

    Predicate = Compare | In | IsNull | And | Or | Not | Custom

with ``from_dict`` (parse a JSON spec), ``compile`` (turn into an ibis
expression for either pre- or post-agg), and ``fields`` (collect
referenced names). All three of the previous walkers now delegate.
Because they all go through ``from_dict``, ``_extract_filter_fields``
and ``_build_post_agg_predicate`` now validate specs the same way the
pre-agg path does: missing ``field``/``operator`` keys, unknown
operators, and empty ``conditions`` raise.

``Filter`` keeps accepting dicts, strings, and callables; reflecting
callables into ``Predicate`` is a follow-up. ``SemanticFilterOp`` is
unchanged.

Drops the now-dead ``OPERATOR_MAPPING`` table, the ``_ibis_*`` two-arg
helpers, ``Filter.OPERATORS`` ClassVar, ``Filter._get_field_expr``, and
the ``operator`` module imports — none had external references.

Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/predicate.py | 286 ++++++++++++++++++ src/boring_semantic_layer/query.py | 190 ++---------- .../tests/test_predicate.py | 169 +++++++++++ 3 files changed, 475 insertions(+), 170 deletions(-) create mode 100644 src/boring_semantic_layer/predicate.py create mode 100644 src/boring_semantic_layer/tests/test_predicate.py diff --git a/src/boring_semantic_layer/predicate.py b/src/boring_semantic_layer/predicate.py new file mode 100644 index 0000000..edde3cd --- /dev/null +++ b/src/boring_semantic_layer/predicate.py @@ -0,0 +1,286 @@ +"""Internal Predicate AST for filter expressions. + +A small algebra over filters that gives every supported operator a single +canonical representation. Replaces the ad-hoc walk-the-dict pattern that +existed in three places: + +- ``query.Filter._parse_json_filter`` (pre-aggregation, attribute access) +- ``query._build_post_agg_predicate`` (post-aggregation, bracket access) +- ``query._extract_filter_fields`` (collect referenced fields) + +JSON filter specs (and string specs, eventually) parse into ``Predicate``; +the same compiler turns ``Predicate`` into an ibis expression for either +pre- or post-aggregation tables. ``SemanticFilterOp`` continues to accept +opaque callables — reflecting callables into ``Predicate`` is a later +step. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable +from typing import Any, ClassVar, Literal + +import ibis +from attrs import field, frozen + + +_COMPARE_OPS: dict[str, Callable[[Any, Any], Any]] = { + "eq": lambda x, y: x == y, + "ne": lambda x, y: x != y, + "lt": lambda x, y: x < y, + "le": lambda x, y: x <= y, + "gt": lambda x, y: x > y, + "ge": lambda x, y: x >= y, + "like": lambda x, y: x.like(y), + "not_like": lambda x, y: ~x.like(y), + "ilike": lambda x, y: x.ilike(y), + "not_ilike": lambda x, y: ~x.ilike(y), +} + +# JSON filter operator strings that map to a Compare node. 
Includes +# legacy aliases (``=``, ``equals``) accepted by the existing parser. +_DICT_COMPARE_OPS: dict[str, str] = { + "=": "eq", + "eq": "eq", + "equals": "eq", + "!=": "ne", + ">": "gt", + ">=": "ge", + "<": "lt", + "<=": "le", + "like": "like", + "not like": "not_like", + "ilike": "ilike", + "not ilike": "not_ilike", +} + + +@frozen +class Compare: + """Two-arg comparison: field value.""" + + op: Literal["eq", "ne", "lt", "le", "gt", "ge", "like", "not_like", "ilike", "not_ilike"] + field: str + value: Any + + def fields(self) -> set[str]: + return {self.field} + + +@frozen +class In: + """Membership test: ``field in values`` (or ``not in`` when negated).""" + + field: str + values: tuple = field(converter=tuple) + negate: bool = False + + def fields(self) -> set[str]: + return {self.field} + + +@frozen +class IsNull: + """Null check (or not-null when negated).""" + + field: str + negate: bool = False + + def fields(self) -> set[str]: + return {self.field} + + +@frozen +class And: + """Conjunction of one or more predicates.""" + + children: tuple = field(converter=tuple) + + def fields(self) -> set[str]: + return set().union(*(c.fields() for c in self.children)) + + +@frozen +class Or: + """Disjunction of one or more predicates.""" + + children: tuple = field(converter=tuple) + + def fields(self) -> set[str]: + return set().union(*(c.fields() for c in self.children)) + + +@frozen +class Not: + """Negation.""" + + predicate: Any + + def fields(self) -> set[str]: + return self.predicate.fields() + + +@frozen +class Custom: + """Escape hatch for callables that can't be reflected into the AST.""" + + fn: Callable + + def fields(self) -> set[str]: + return set() + + +Predicate = Compare | In | IsNull | And | Or | Not | Custom + + +_COMPOUND_OPS: ClassVar = frozenset({"AND", "OR"}) + + +def from_dict(spec: dict[str, Any]) -> Predicate: + """Parse a JSON-style filter spec into a ``Predicate``. + + Mirrors the schema that ``query.Filter`` accepts. 
Unknown operators + raise ``ValueError`` rather than falling through silently. + """ + if not isinstance(spec, dict): + raise ValueError(f"Filter spec must be a dict, got {type(spec).__name__}") + + op = spec.get("operator") + if op is None: + raise KeyError( + "Missing required keys in filter: 'field' and 'operator' are required" + ) + + if op == "AND": + return And(children=tuple(from_dict(c) for c in _require_conditions(spec, op))) + if op == "OR": + return Or(children=tuple(from_dict(c) for c in _require_conditions(spec, op))) + + field_name = spec.get("field") + if field_name is None: + raise KeyError( + "Missing required keys in filter: 'field' and 'operator' are required" + ) + + if op == "is null": + _reject_value_keys(spec, op) + return IsNull(field=field_name, negate=False) + if op == "is not null": + _reject_value_keys(spec, op) + return IsNull(field=field_name, negate=True) + + if op == "in": + return In(field=field_name, values=_require_values(spec, op), negate=False) + if op == "not in": + return In(field=field_name, values=_require_values(spec, op), negate=True) + + canonical = _DICT_COMPARE_OPS.get(op) + if canonical is None: + raise ValueError(f"Unsupported operator: {op}") + + if "value" not in spec: + raise ValueError(f"Operator {op!r} requires 'value' field") + return Compare(op=canonical, field=field_name, value=spec["value"]) + + +def _require_conditions(spec: dict, op: str) -> Iterable[dict]: + conditions = spec.get("conditions") + if not conditions: + raise ValueError(f"Compound filter {op!r} must have non-empty 'conditions' list") + return conditions + + +def _require_values(spec: dict, op: str) -> tuple: + values = spec.get("values") + if values is None: + raise ValueError(f"Operator {op!r} requires 'values' field") + return tuple(values) + + +def _reject_value_keys(spec: dict, op: str) -> None: + if any(k in spec for k in ("value", "values")): + raise ValueError(f"Operator {op!r} should not have 'value' or 'values' fields") + + +def 
_convert_literal(value: Any, ibis_module) -> Any: + """Convert string date/timestamp values to typed ibis literals. + + Mirrors ``query.Filter._convert_filter_value``: backends like Athena + require typed date literals or fail with TYPE_MISMATCH. Returns the + value unchanged when it is not a date/timestamp string. + """ + if not isinstance(value, str): + return value + for dtype in ("timestamp", "date"): + try: + return ibis_module.literal(value, type=dtype) + except (ValueError, TypeError): + pass + return value + + +def _field_accessor(table, name: str, *, post_agg: bool): + """Resolve a field name on the table. + + Pre-aggregation: use attribute access on ``ibis._`` so the predicate + can be resolved against any table later (it is a Deferred). + Pre-aggregation also strips a model-prefix from dotted names because + joined tables flatten columns to the top level. + + Post-aggregation: use bracket access to preserve dotted names like + ``orders.total_amount`` that survive into the aggregated table. + """ + if post_agg: + return table[name] + if "." in name: + _prefix, unprefixed = name.split(".", 1) + return getattr(table, unprefixed) + return getattr(table, name) + + +def compile( # noqa: A001 + pred: Predicate, + table, + *, + post_agg: bool = False, + ibis_module=ibis, +) -> Any: + """Compile *pred* into an ibis expression against *table*. + + *table* is typically a Deferred (``ibis._``) for pre-agg or an actual + aggregated relation for post-agg. ``ibis_module`` controls the flavor + of literal construction (plain ibis vs xorq vendored). 
+ """ + if isinstance(pred, And): + compiled = [compile(c, table, post_agg=post_agg, ibis_module=ibis_module) for c in pred.children] + result = compiled[0] + for c in compiled[1:]: + result = result & c + return result + if isinstance(pred, Or): + compiled = [compile(c, table, post_agg=post_agg, ibis_module=ibis_module) for c in pred.children] + result = compiled[0] + for c in compiled[1:]: + result = result | c + return result + if isinstance(pred, Not): + return ~compile(pred.predicate, table, post_agg=post_agg, ibis_module=ibis_module) + if isinstance(pred, IsNull): + col = _field_accessor(table, pred.field, post_agg=post_agg) + return col.notnull() if pred.negate else col.isnull() + if isinstance(pred, In): + col = _field_accessor(table, pred.field, post_agg=post_agg) + values = [_convert_literal(v, ibis_module) for v in pred.values] + return col.notin(values) if pred.negate else col.isin(values) + if isinstance(pred, Compare): + col = _field_accessor(table, pred.field, post_agg=post_agg) + value = _convert_literal(pred.value, ibis_module) + return _COMPARE_OPS[pred.op](col, value) + if isinstance(pred, Custom): + return pred.fn(table) + raise TypeError(f"Unknown predicate node: {type(pred).__name__}") + + +def fields(pred: Predicate) -> set[str]: + """Return the set of field names referenced by *pred*.""" + return pred.fields() diff --git a/src/boring_semantic_layer/query.py b/src/boring_semantic_layer/query.py index 5c20027..5e214f2 100644 --- a/src/boring_semantic_layer/query.py +++ b/src/boring_semantic_layer/query.py @@ -7,7 +7,6 @@ from __future__ import annotations from collections.abc import Callable, Mapping, Sequence -from operator import eq, ge, gt, le, lt, ne from typing import Any, ClassVar, Literal import ibis @@ -71,68 +70,7 @@ def _get_ibis_api(): ) -# Helper functions using operator module instead of lambdas -def _ibis_isin(x, y): - return x.isin(y) - - -def _ibis_not_isin(x, y): - return ~x.isin(y) - - -def _ibis_like(x, y): - return 
x.like(y) - - -def _ibis_not_like(x, y): - return ~x.like(y) - - -def _ibis_ilike(x, y): - return x.ilike(y) - - -def _ibis_not_ilike(x, y): - return ~x.ilike(y) - - -def _ibis_isnull(x, _): - return x.isnull() - - -def _ibis_notnull(x, _): - return x.notnull() - - -def _ibis_and(x, y): - return x & y - - -def _ibis_or(x, y): - return x | y - - -# Operator mapping using operator module functions where possible -OPERATOR_MAPPING: FrozenDict = { - "=": eq, - "eq": eq, - "equals": eq, - "!=": ne, - ">": gt, - ">=": ge, - "<": lt, - "<=": le, - "in": _ibis_isin, - "not in": _ibis_not_isin, - "like": _ibis_like, - "not like": _ibis_not_like, - "ilike": _ibis_ilike, - "not ilike": _ibis_not_ilike, - "is null": _ibis_isnull, - "is not null": _ibis_notnull, - "AND": _ibis_and, - "OR": _ibis_or, -} +# Filter parsing and compilation lives in ``boring_semantic_layer.predicate``. @curry @@ -233,7 +171,6 @@ class Filter: filter: FrozenDict | str | Callable - OPERATORS: ClassVar[set] = set(OPERATOR_MAPPING.keys()) COMPOUND_OPERATORS: ClassVar[set] = {"AND", "OR"} def __attrs_post_init__(self) -> None: @@ -261,78 +198,19 @@ def _convert_filter_value(self, value: Any) -> Any: # Not a date/timestamp, return original value return value - def _get_field_expr(self, field: str) -> Any: - """Get field expression using ibis._ for unbound reference. - - For prefixed fields (e.g., 'customers.country'), use only the field name - since joined tables flatten the columns to the top level. - """ - _ibis = _get_ibis_api() - if "." 
in field: - # Extract just the field name, ignoring the table prefix - # e.g., 'customers.country' -> 'country' - _table_name, field_name = field.split(".", 1) - return getattr(_ibis._, field_name) - return getattr(_ibis._, field) - - def _parse_json_filter(self, filter_obj: FrozenDict) -> Any: - """Parse JSON filter object into ibis expression.""" - # Compound filters (AND/OR) - if filter_obj.get("operator") in self.COMPOUND_OPERATORS: - conditions = filter_obj.get("conditions") - if not conditions: - raise ValueError("Compound filter must have non-empty conditions list") - expr = self._parse_json_filter(conditions[0]) - for cond in conditions[1:]: - next_expr = self._parse_json_filter(cond) - expr = OPERATOR_MAPPING[filter_obj["operator"]](expr, next_expr) - return expr - - # Simple filter - field = filter_obj.get("field") - op = filter_obj.get("operator") - if field is None or op is None: - raise KeyError( - "Missing required keys in filter: 'field' and 'operator' are required", - ) - - field_expr = self._get_field_expr(field) - - if op not in self.OPERATORS: - raise ValueError(f"Unsupported operator: {op}") - - # List membership operators - if op in ("in", "not in"): - values = filter_obj.get("values") - if values is None: - raise ValueError(f"Operator '{op}' requires 'values' field") - # Convert each value for date/timestamp support - converted_values = [self._convert_filter_value(v) for v in values] - return OPERATOR_MAPPING[op](field_expr, converted_values) - - # Null checks - if op in ("is null", "is not null"): - if any(k in filter_obj for k in ("value", "values")): - raise ValueError( - f"Operator '{op}' should not have 'value' or 'values' fields", - ) - return OPERATOR_MAPPING[op](field_expr, None) - - # Single value operators - value = filter_obj.get("value") - if value is None: - raise ValueError(f"Operator '{op}' requires 'value' field") - # Convert value for date/timestamp support - converted_value = self._convert_filter_value(value) - return 
OPERATOR_MAPPING[op](field_expr, converted_value) - def to_callable(self) -> Callable: """Convert filter to callable that can be used with SemanticTable.filter().""" + from . import predicate as pred_mod from .ops import _ensure_xorq_table if isinstance(self.filter, dict): - expr = self._parse_json_filter(self.filter) - return lambda t: expr.resolve(_ensure_xorq_table(t)) + pred = pred_mod.from_dict(self.filter) + ibis_module = _get_ibis_api() + return lambda t: pred_mod.compile( + pred, + ibis_module._, + ibis_module=ibis_module, + ).resolve(_ensure_xorq_table(t)) elif isinstance(self.filter, str): _ibis = _get_ibis_api() expr = safe_eval( @@ -414,15 +292,11 @@ def _normalize_order_by( def _extract_filter_fields(filter_spec: dict) -> set[str]: """Extract all field names referenced by a dict filter (including compound).""" + from . import predicate as pred_mod + if not isinstance(filter_spec, dict): return set() - if filter_spec.get("operator") in ("AND", "OR"): - fields: set[str] = set() - for cond in filter_spec.get("conditions", []): - fields |= _extract_filter_fields(cond) - return fields - field = filter_spec.get("field") - return {field} if field else set() + return pred_mod.fields(pred_mod.from_dict(filter_spec)) def _normalize_filter_fields( @@ -448,40 +322,16 @@ def _normalize_filter_fields( def _build_post_agg_predicate(filter_obj: dict) -> Any: - """Build an ibis predicate for post-aggregation filters. + """Build an ibis predicate (``Deferred``) for post-aggregation filters. - Uses bracket access (``t[field]``) instead of attribute access so that - dotted column names from joined models (e.g. ``orders.total_amount``) - resolve correctly on the aggregated table. + Delegates to the ``Predicate`` AST in ``predicate``. Bracket-access + field resolution preserves dotted names from joined models (e.g. + ``orders.total_amount``) on the aggregated table. 
""" - if filter_obj.get("operator") in ("AND", "OR"): - conditions = filter_obj.get("conditions", []) - expr = _build_post_agg_predicate(conditions[0]) - for cond in conditions[1:]: - next_expr = _build_post_agg_predicate(cond) - expr = OPERATOR_MAPPING[filter_obj["operator"]](expr, next_expr) - return expr - - field = filter_obj["field"] - op = filter_obj["operator"] - # Use bracket access on ibis._ to preserve dotted names - field_expr = ibis._[field] - - if op in ("is null", "is not null"): - return OPERATOR_MAPPING[op](field_expr, None) - if op in ("in", "not in"): - return OPERATOR_MAPPING[op](field_expr, filter_obj.get("values", [])) - - value = filter_obj.get("value") - # Convert date/timestamp strings - if isinstance(value, str): - for dtype in ("timestamp", "date"): - try: - value = ibis.literal(value, type=dtype) - break - except (ValueError, TypeError): - pass - return OPERATOR_MAPPING[op](field_expr, value) + from . import predicate as pred_mod + + pred = pred_mod.from_dict(filter_obj) + return pred_mod.compile(pred, ibis._, post_agg=True, ibis_module=ibis) def _normalize_post_agg_filter( diff --git a/src/boring_semantic_layer/tests/test_predicate.py b/src/boring_semantic_layer/tests/test_predicate.py new file mode 100644 index 0000000..637e7eb --- /dev/null +++ b/src/boring_semantic_layer/tests/test_predicate.py @@ -0,0 +1,169 @@ +"""Unit tests for the internal Predicate AST.""" + +from __future__ import annotations + +import ibis +import pandas as pd +import pytest + +from boring_semantic_layer import predicate as pred_mod +from boring_semantic_layer.predicate import And, Compare, In, IsNull, Not, Or + + +# --------------------------------------------------------------------------- +# from_dict: parsing the JSON filter spec +# --------------------------------------------------------------------------- + + +def test_from_dict_simple_eq(): + p = pred_mod.from_dict({"operator": "=", "field": "country", "value": "US"}) + assert p == Compare(op="eq", 
field="country", value="US") + + +def test_from_dict_canonicalizes_aliases(): + assert pred_mod.from_dict( + {"operator": "equals", "field": "x", "value": 1} + ) == Compare(op="eq", field="x", value=1) + assert pred_mod.from_dict( + {"operator": "!=", "field": "x", "value": 1} + ) == Compare(op="ne", field="x", value=1) + + +def test_from_dict_in_and_not_in(): + p = pred_mod.from_dict({"operator": "in", "field": "tier", "values": ["a", "b"]}) + assert p == In(field="tier", values=("a", "b"), negate=False) + + p = pred_mod.from_dict( + {"operator": "not in", "field": "tier", "values": ["a"]} + ) + assert p == In(field="tier", values=("a",), negate=True) + + +def test_from_dict_null_checks_reject_value_keys(): + pred_mod.from_dict({"operator": "is null", "field": "x"}) # ok + with pytest.raises(ValueError, match="should not have"): + pred_mod.from_dict({"operator": "is null", "field": "x", "value": 1}) + + +def test_from_dict_compound(): + p = pred_mod.from_dict( + { + "operator": "AND", + "conditions": [ + {"operator": "=", "field": "a", "value": 1}, + {"operator": ">", "field": "b", "value": 0}, + ], + } + ) + assert isinstance(p, And) + assert len(p.children) == 2 + + +def test_from_dict_rejects_empty_compound(): + with pytest.raises(ValueError, match="non-empty"): + pred_mod.from_dict({"operator": "AND", "conditions": []}) + + +def test_from_dict_unsupported_operator(): + with pytest.raises(ValueError, match="Unsupported operator"): + pred_mod.from_dict({"operator": "WAT", "field": "x", "value": 1}) + + +# --------------------------------------------------------------------------- +# fields: collect referenced field names +# --------------------------------------------------------------------------- + + +def test_fields_simple(): + p = Compare(op="eq", field="country", value="US") + assert pred_mod.fields(p) == {"country"} + + +def test_fields_compound(): + p = And( + children=( + Compare(op="eq", field="a", value=1), + Or( + children=( + Compare(op="gt", 
field="b", value=0), + IsNull(field="c"), + ) + ), + ) + ) + assert pred_mod.fields(p) == {"a", "b", "c"} + + +def test_fields_not_passes_through(): + p = Not(predicate=Compare(op="eq", field="x", value=1)) + assert pred_mod.fields(p) == {"x"} + + +# --------------------------------------------------------------------------- +# compile: round-trip via an in-memory duckdb table +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def people_table(): + con = ibis.duckdb.connect() + df = pd.DataFrame( + { + "country": ["US", "US", "FR", "DE", None], + "age": [10, 25, 30, 45, 50], + } + ) + return con.create_table("people", df) + + +def _execute(pred, table): + expr = pred_mod.compile(pred, ibis._) + return table.filter(expr.resolve(table)).execute() + + +def test_compile_eq(people_table): + df = _execute(Compare(op="eq", field="country", value="US"), people_table) + assert sorted(df["age"].tolist()) == [10, 25] + + +def test_compile_in(people_table): + df = _execute( + In(field="country", values=("FR", "DE"), negate=False), people_table + ) + assert sorted(df["age"].tolist()) == [30, 45] + + +def test_compile_isnull(people_table): + df = _execute(IsNull(field="country"), people_table) + assert df["age"].tolist() == [50] + + +def test_compile_and(people_table): + p = And( + children=( + Compare(op="eq", field="country", value="US"), + Compare(op="gt", field="age", value=20), + ) + ) + df = _execute(p, people_table) + assert df["age"].tolist() == [25] + + +def test_compile_or_and_not(people_table): + p = Or( + children=( + Compare(op="eq", field="country", value="FR"), + Not(predicate=Compare(op="lt", field="age", value=40)), + ) + ) + df = _execute(p, people_table) + # FR(30) plus age >= 40 (45, 50) + assert sorted(df["age"].tolist()) == [30, 45, 50] + + +def test_compile_post_agg_uses_bracket_access(): + """Post-agg compilation preserves dotted column names.""" + p = Compare(op="gt", field="orders.total", 
value=100) + expr = pred_mod.compile(p, ibis._, post_agg=True) + # Sanity: it is a Deferred — actual semantic test is covered elsewhere + assert hasattr(expr, "resolve")