From a49680f7d9ec4e02d47755bc22566836a0fa6279 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 4 May 2026 07:31:18 -0400 Subject: [PATCH 1/6] refactor: convert ops.py to package and extract self-contained sections Convert src/boring_semantic_layer/ops.py to a package (ops/) and extract two cleanly-bounded chunks into submodules: - ops/_values.py (166 lines): Dimension, Measure value objects plus the prefix proxies that support model-prefix navigation in joined dimension lambdas. Imports nothing from the relational core; the only callback to _core (_format_column_error in the AttributeError path) is a lazy in-method import to avoid circular module loading. - ops/_column_extraction.py (369 lines): ColumnTracker, ColumnExtractionResult, JoinColumnExtractionResult, TableColumnRequirements and the projection-pushdown extraction helpers. Self-contained except for one lazy _unwrap import inside _extract_requirements_from_measures. ops/__init__.py re-exports every name external callers reference, so all existing `from boring_semantic_layer.ops import X` imports keep working unchanged. _core.py imports the moved names back so its own internal call sites are unchanged too. The remaining 4,860 lines in _core.py are the relational op classes (Semantic*Op) and their tightly-coupled helpers (root-model traversal, join-key/rename machinery, aggregation planning). Those resist a clean split: helpers reference SemanticTableOp which is defined mid-file, and the join logic threads through several intricate fixes documented in MEMORY.md (multi-way join column ambiguity, _RenamedResolver, pre-aggregation join direction). Splitting them further would risk subtle behavior changes for limited structural gain. A future pass could pull these apart but it warrants its own focused branch. No behavior change. 930 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/ops/__init__.py | 67 ++ .../ops/_column_extraction.py | 369 +++++++++++ .../{ops.py => ops/_core.py} | 596 ++---------------- src/boring_semantic_layer/ops/_values.py | 166 +++++ 4 files changed, 665 insertions(+), 533 deletions(-) create mode 100644 src/boring_semantic_layer/ops/__init__.py create mode 100644 src/boring_semantic_layer/ops/_column_extraction.py rename src/boring_semantic_layer/{ops.py => ops/_core.py} (90%) create mode 100644 src/boring_semantic_layer/ops/_values.py diff --git a/src/boring_semantic_layer/ops/__init__.py b/src/boring_semantic_layer/ops/__init__.py new file mode 100644 index 0000000..11b4158 --- /dev/null +++ b/src/boring_semantic_layer/ops/__init__.py @@ -0,0 +1,67 @@ +"""Semantic layer operations. + +This package was split out of a single ``ops.py``. ``__init__`` re-exports +the names that callers import from ``boring_semantic_layer.ops`` so existing +imports keep working unchanged. 
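+
+A sketch of the compatibility contract (the import below is illustrative;
+both names are re-exported further down in this module)::
+
+    from boring_semantic_layer.ops import Dimension, SemanticTableOp
+    # resolves to the same objects the pre-split ``ops.py`` exposed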
+""" + +from ._column_extraction import ( + ColumnExtractionResult, + ColumnTracker, + JoinColumnExtractionResult, + TableColumnRequirements, + _extract_columns_from_callable, + _extract_join_key_columns, + _extract_requirements_from_keys, + _extract_requirements_from_measures, + _make_tracking_proxy, + _parse_prefixed_field, +) +from ._core import ( + # Public value objects + Dimension, + Measure, + # Public Op classes + SemanticAggregateOp, + SemanticFilterOp, + SemanticGroupByOp, + SemanticIndexOp, + SemanticJoinOp, + SemanticLimitOp, + SemanticMutateOp, + SemanticOrderByOp, + SemanticProjectOp, + SemanticTableOp, + SemanticUnnestOp, + # Private helpers used by other modules in this package + _CallableWrapper, + _classify_measure, + _collect_measure_refs, + _ensure_xorq_table, + _find_all_root_models, + _get_field_dict, + _get_merged_fields, + _is_deferred, + _make_schema, + _merge_fields_with_prefixing, + _normalize_join_predicate, + _normalize_to_name, + _rebind_to_backend, + _rebind_to_canonical_backend, +) + +__all__ = [ + "Dimension", + "Measure", + "SemanticAggregateOp", + "SemanticFilterOp", + "SemanticGroupByOp", + "SemanticIndexOp", + "SemanticJoinOp", + "SemanticLimitOp", + "SemanticMutateOp", + "SemanticOrderByOp", + "SemanticProjectOp", + "SemanticTableOp", + "SemanticUnnestOp", +] diff --git a/src/boring_semantic_layer/ops/_column_extraction.py b/src/boring_semantic_layer/ops/_column_extraction.py new file mode 100644 index 0000000..879aaeb --- /dev/null +++ b/src/boring_semantic_layer/ops/_column_extraction.py @@ -0,0 +1,369 @@ +"""Column tracking & per-table column-requirement extraction. + +Used by projection pushdown to figure out which columns each leaf table +must produce. Split from ``_core.py`` to keep the relational ops module +focused on Op definitions. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Any + +from attrs import field, frozen +from ibis.expr import operations as ibis_ops +from ibis.expr import types as ir + +from .._xorq import FrozenDict +from ..graph_utils import walk_nodes + +logger = logging.getLogger(__name__) + + +# ============================================================================== +# Column Tracking for Projection Pushdown +# ============================================================================== + + +@frozen +class ColumnTracker: + """Immutable tracker for column references during expression evaluation. + + Uses frozenset for tracked columns. New columns are added by creating + new tracker instances with updated sets. + """ + + columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + + def with_column(self, col_name: str) -> ColumnTracker: + """Return new tracker with additional column.""" + return ColumnTracker(columns=self.columns | {col_name}) + + def merge(self, other: ColumnTracker) -> ColumnTracker: + """Return new tracker with merged columns.""" + return ColumnTracker(columns=self.columns | other.columns) + + +@frozen +class ColumnExtractionResult: + """Result of column extraction with error handling. + + Separates successful extraction from error cases. 
+ """ + + columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + extraction_failed: bool = False + error_type: type[Exception] | None = None + + @classmethod + def success(cls, columns: set[str] | frozenset[str]) -> ColumnExtractionResult: + """Create successful result.""" + return cls(columns=frozenset(columns), extraction_failed=False) + + @classmethod + def failure(cls, error: Exception) -> ColumnExtractionResult: + """Create failure result with error information.""" + return cls( + columns=frozenset(), + extraction_failed=True, + error_type=type(error), + ) + + def is_success(self) -> bool: + """Check if extraction succeeded.""" + return not self.extraction_failed + + +@frozen +class JoinColumnExtractionResult: + """Result of join column extraction for both tables.""" + + left_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + right_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + extraction_failed: bool = False + error_type: type[Exception] | None = None + + @classmethod + def success( + cls, + left: set[str] | frozenset[str], + right: set[str] | frozenset[str], + ) -> JoinColumnExtractionResult: + """Create successful result.""" + return cls( + left_columns=frozenset(left), + right_columns=frozenset(right), + extraction_failed=False, + ) + + @classmethod + def failure(cls, error: Exception) -> JoinColumnExtractionResult: + """Create failure result with error information.""" + return cls( + left_columns=frozenset(), + right_columns=frozenset(), + extraction_failed=True, + error_type=type(error), + ) + + def is_success(self) -> bool: + """Check if extraction succeeded.""" + return not self.extraction_failed + + +def _make_tracking_proxy( + table: ir.Table, + on_access: Callable[[str], None], +) -> Any: + """Create tracking proxy with custom access handler. + + Composable factory that enables different tracking strategies + via the on_access callback. + """ + + class _TrackingProxy: + """Proxy that tracks attribute and item access.""" + + def __init__(self, inner_table: ir.Table, access_handler: Callable[[str], None]): + object.__setattr__(self, "_table", inner_table) + object.__setattr__(self, "_on_access", access_handler) + + def __getattr__(self, name: str): + if name.startswith("_"): + return getattr(self._table, name) + self._on_access(name) + return getattr(self._table, name) + + def __getitem__(self, name: str): + self._on_access(name) + return self._table[name] + + return _TrackingProxy(table, on_access) + + +def _extract_columns_from_callable( + fn: Any, + table: ir.Table, +) -> ColumnExtractionResult: + """Extract column names referenced by a callable. + + Uses immutable tracking and returns structured result. 
+ """ + if not callable(fn): + return ColumnExtractionResult.success(frozenset()) + + tracker_ref = [ColumnTracker()] + + def on_column_access(col_name: str) -> None: + tracker_ref[0] = tracker_ref[0].with_column(col_name) + + try: + tracking_proxy = _make_tracking_proxy(table, on_column_access) + fn(tracking_proxy) + return ColumnExtractionResult.success(tracker_ref[0].columns) + + except Exception as e: + return ColumnExtractionResult.failure(e) + + +def _extract_join_key_columns( + on: Callable[[Any, Any], ir.BooleanValue], + left_table: ir.Table, + right_table: ir.Table, +) -> JoinColumnExtractionResult: + left_tracker_ref = [ColumnTracker()] + right_tracker_ref = [ColumnTracker()] + + def on_left_access(col_name: str) -> None: + left_tracker_ref[0] = left_tracker_ref[0].with_column(col_name) + + def on_right_access(col_name: str) -> None: + right_tracker_ref[0] = right_tracker_ref[0].with_column(col_name) + + try: + left_proxy = _make_tracking_proxy(left_table, on_left_access) + right_proxy = _make_tracking_proxy(right_table, on_right_access) + on(left_proxy, right_proxy) + + return JoinColumnExtractionResult.success( + left_tracker_ref[0].columns, + right_tracker_ref[0].columns, + ) + + except Exception as e: + return JoinColumnExtractionResult.failure(e) + + +# ============================================================================== +# Table Column Requirements +# ============================================================================== + + +@frozen +class TableColumnRequirements: + """Immutable representation of column requirements per table. + + Maps table names to sets of required column names. + """ + + requirements: FrozenDict[str, frozenset[str]] = field( + factory=lambda: FrozenDict({}), + converter=lambda d: FrozenDict( + {k: frozenset(v) if not isinstance(v, frozenset) else v for k, v in d.items()}, + ), + ) + + def with_column(self, table_name: str, col_name: str) -> TableColumnRequirements: + """Return new requirements with additional column for table.""" + current_cols = self.requirements.get(table_name, frozenset()) + updated_cols = current_cols | {col_name} + + return TableColumnRequirements( + requirements=dict(self.requirements) | {table_name: updated_cols}, + ) + + def with_columns( + self, + table_name: str, + col_names: Iterable[str], + ) -> TableColumnRequirements: + """Return new requirements with multiple columns for table.""" + current_cols = self.requirements.get(table_name, frozenset()) + updated_cols = current_cols | frozenset(col_names) + + return TableColumnRequirements( + requirements=dict(self.requirements) | {table_name: updated_cols}, + ) + + def merge(self, other: TableColumnRequirements) -> TableColumnRequirements: + """Merge requirements from another instance.""" + merged_dict = dict(self.requirements) + + for table, cols in other.requirements.items(): + if table in merged_dict: + merged_dict[table] = merged_dict[table] | cols + else: + merged_dict[table] = cols + + return TableColumnRequirements(requirements=merged_dict) + + def to_dict(self) -> dict[str, set[str]]: + """Convert to mutable dict for API compatibility.""" + return {table: set(cols) for table, cols in self.requirements.items()} + + +def _parse_prefixed_field(field_name: str) -> tuple[str | None, str]: + """Parse potentially prefixed field name. + + Args: + field_name: Field name, possibly prefixed (e.g., "table.column") + + Returns: + Tuple of (table_name or None, column_name) + """ + if "." 
in field_name: + table, col = field_name.split(".", 1) + return (table, col) + return (None, field_name) + + +def _extract_requirements_from_keys( + keys: Iterable[str], + merged_dimensions: Mapping[str, Any], + all_roots: Sequence[Any], + table: ir.Table, +) -> TableColumnRequirements: + """Extract column requirements from group-by keys using graph traversal.""" + requirements = TableColumnRequirements() + + for key in keys: + table_name, col_name = _parse_prefixed_field(key) + + if table_name: + # Prefixed: we know the table + requirements = requirements.with_column(table_name, col_name) + else: + # Unprefixed: resolve dimension or use conservative fallback + if key in merged_dimensions: + dim_fn = merged_dimensions[key] + + try: + # Evaluate the dimension to get an Ibis expression + dim_expr = dim_fn(table) + + # Walk the expression graph to find all Field nodes (column references) + field_names = {node.name for node in walk_nodes(ibis_ops.Field, dim_expr)} + + # Filter to only actual columns in the table schema + actual_cols = {col for col in field_names if col in table.columns} + + if actual_cols: + for root in all_roots: + if root.name: + requirements = requirements.with_columns(root.name, actual_cols) + else: + # Fallback: assume key name is column name + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, key) + except Exception: + logger.debug( + "dimension graph traversal failed for %r; " + "treating key name as column name", + key, + exc_info=True, + ) + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, key) + else: + # Raw column + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, key) + + return requirements + + +def _extract_requirements_from_measures( + aggs: Mapping[str, Callable], + all_roots: Sequence[Any], + table: ir.Table, +) -> TableColumnRequirements: + """Extract column requirements from measure aggregations using graph traversal.""" + from ._core import _unwrap + + requirements = TableColumnRequirements() + + for measure_name, measure_fn in aggs.items(): + fn = _unwrap(measure_fn) + + try: + # Evaluate the measure to get an Ibis expression + measure_expr = fn(table) + + # Walk the expression graph to find all Field nodes (column references) + field_names = {node.name for node in walk_nodes(ibis_ops.Field, measure_expr)} + + # Filter to only actual columns in the table schema + actual_cols = {col for col in field_names if col in table.columns} + + if actual_cols: + for root in all_roots: + if root.name: + requirements = requirements.with_columns(root.name, actual_cols) + except Exception: + logger.debug( + "measure graph traversal failed for %r; " + "falling back to name-based column inference", + measure_name, + exc_info=True, + ) + # Conservative fallback: if measure name looks like a column, include it + if measure_name.isidentifier(): + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, measure_name) + + return requirements diff --git a/src/boring_semantic_layer/ops.py b/src/boring_semantic_layer/ops/_core.py similarity index 90% rename from src/boring_semantic_layer/ops.py rename to src/boring_semantic_layer/ops/_core.py index 695a2a4..f886139 100644 --- a/src/boring_semantic_layer/ops.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -16,7 +16,7 @@ from ibis.expr.operations.relations import Field, Relation from ibis.expr.schema import Schema -from ._xorq import ( +from .._xorq import ( FrozenDict, 
FrozenOrderedDict, Schema as XorqSchema, @@ -44,10 +44,10 @@ def _reductions_for_expr(expr): from returns.result import Success, safe from toolz import curry -from . import projection_utils -from .compile_all import compile_grouped_with_all -from .graph_utils import walk_nodes -from .measure_scope import ( +from .. import projection_utils +from ..compile_all import compile_grouped_with_all +from ..graph_utils import walk_nodes +from ..measure_scope import ( AggregationExpr, AllOf, BinOp, @@ -56,7 +56,7 @@ def _reductions_for_expr(expr): MeasureScope, MethodCall, ) -from .nested_access import NestedAccessMarker +from ..nested_access import NestedAccessMarker logger = logging.getLogger(__name__) @@ -91,11 +91,6 @@ def __getattr__(self, name): return getattr(self._table, mapped) -def _is_deferred(expr) -> bool: - # Duck-type check: works for both ibis and xorq Deferred objects - return hasattr(expr, "_resolver") and hasattr(expr, "resolve") - - def _normalize_to_name(arg: str | Deferred) -> str: """Convert a string or simple ``_.name`` Deferred to a plain string name. @@ -191,7 +186,7 @@ def _compound_predicate(left, right): if TYPE_CHECKING: - from .expr import ( + from ..expr import ( SemanticFilter, SemanticGroupBy, SemanticLimit, @@ -208,7 +203,7 @@ def _patch_xorq_sortkey_compat(): """ from ibis.expr.operations.sortkeys import SortKey as IbisSortKey - from ._xorq import SortKey as XorqSortKey, map_ibis + from .._xorq import SortKey as XorqSortKey, map_ibis if IbisSortKey in map_ibis.registry: return # already patched @@ -239,7 +234,7 @@ def _ensure_xorq_table(table): _patch_xorq_sortkey_compat() if "xorq.vendor.ibis" not in type(table).__module__: try: - from ._xorq import from_ibis + from .._xorq import from_ibis return from_ibis(table) except Exception: @@ -262,7 +257,7 @@ def _rebind_to_backend(expr, target_backend): reason; callers must pass a xorq-vendored ``target_backend``. """ try: - from ._xorq import relations as xorq_rel + from .._xorq import relations as xorq_rel except ImportError: return expr @@ -295,7 +290,7 @@ def _rebind_to_canonical_backend(expr): No-op on plain ibis expressions (not xorq-vendored). """ try: - from ._xorq import relations as xorq_rel, walk_nodes + from .._xorq import relations as xorq_rel, walk_nodes except ImportError: return expr @@ -318,7 +313,7 @@ def _to_untagged(source: Any) -> ir.Table: def _semantic_table(*args, **kwargs) -> SemanticTable: - from .expr import SemanticModel + from ..expr import SemanticModel return SemanticModel(*args, **kwargs) @@ -482,7 +477,7 @@ def _resolve_expr(expr: Deferred | Callable | Any, scope: ir.Table) -> ir.Value: scope_is_xorq = "xorq.vendor.ibis" in scope_module if result_is_regular_ibis and scope_is_xorq: - from ._xorq import from_ibis + from .._xorq import from_ibis result = from_ibis(result) @@ -642,7 +637,7 @@ def _infer_unnest(fn: Callable, table: Any) -> tuple[str, ...]: to_semantic_table(tbl).unnest("hits").with_measures(...) -> ("hits",) unnested.unnest("product").with_measures(...) 
-> ("product",) """ - from .expr import SemanticUnnest + from ..expr import SemanticUnnest if isinstance(table, SemanticUnnest): op = table.op() @@ -854,7 +849,7 @@ def _classify_measure( fn_or_expr: Any, scope: Any, measure_name: str | None = None ) -> tuple[str, Any]: """Classify measure as 'calc' or 'base' with appropriate handling.""" - from .measure_scope import validate_calc_ast + from ..measure_scope import validate_calc_ast expr, description, requires_unnest, metadata = _extract_measure_metadata(fn_or_expr) @@ -933,149 +928,15 @@ def _format_column_error(e: AttributeError, table: ir.Table) -> str: return " ".join(parts) -class _DimPrefixProxy: - """Resolves ``proxy.column`` to ``dims["prefix.column"](table)``.""" - - __slots__ = ("_tbl", "_dims", "_prefix") - - def __init__(self, tbl, dims: dict, prefix: str): - object.__setattr__(self, "_tbl", tbl) - object.__setattr__(self, "_dims", dims) - object.__setattr__(self, "_prefix", prefix) - - def __getattr__(self, name: str): - full_name = f"{self._prefix}.{name}" - if full_name in self._dims: - return self._dims[full_name](self._tbl) - raise AttributeError( - f"No dimension '{full_name}' found. " - f"Available dimensions with prefix '{self._prefix}.': " - f"{[k for k in self._dims if k.startswith(self._prefix + '.')]}" - ) - - -class _DimensionTableProxy: - """Proxy that wraps an ibis table to support model-prefix navigation. - - Allows dimension lambdas like ``lambda t: t.flights.carrier`` to work on - joined tables by resolving ``t.flights.carrier`` through the merged - dimension map (``dims["flights.carrier"](table)``). - """ - - __slots__ = ("_tbl", "_dims") - - def __init__(self, tbl, dims: dict): - object.__setattr__(self, "_tbl", tbl) - object.__setattr__(self, "_dims", dims) - - def __getattr__(self, name: str): - prefix = f"{name}." - if any(k.startswith(prefix) for k in self._dims): - return _DimPrefixProxy(self._tbl, self._dims, name) - return getattr(self._tbl, name) - - def __getitem__(self, name: str): - if name in self._dims: - return self._dims[name](self._tbl) - return self._tbl[name] - - @property - def columns(self): - return self._tbl.columns - - -@frozen(kw_only=True, slots=True) -class Dimension: - expr: Callable[[ir.Table], ir.Value] | Deferred - description: str | None = None - is_entity: bool = False - is_time_dimension: bool = False - is_event_timestamp: bool = False - smallest_time_grain: str | None = None - derived_dimensions: tuple[str, ...] = () - metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) - - def __call__(self, table: ir.Table, _dims: dict | None = None) -> ir.Value: - try: - return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) - except AttributeError as e: - # Retry with a prefix-aware proxy for joined tables where - # model prefixes are used (e.g., lambda t: t.flights.carrier) - if _dims and not _is_deferred(self.expr) and callable(self.expr): - try: - proxy = _DimensionTableProxy(table, _dims) - return self.expr(proxy) - except AttributeError as proxy_err: - # Preserve explicit prefix-proxy errors (e.g. missing - # "model.field") to avoid silent fallback to unprefixed - # columns, but keep normal missing-column errors on the - # original table so they get the helpful formatter below. 
- if str(proxy_err).startswith("No dimension '"): - raise - except Exception: - pass - # Provide helpful error for missing columns - if "'Table' object has no attribute" in str( - e - ) or "'Join' object has no attribute" in str(e): - raise AttributeError(_format_column_error(e, table)) from e - raise - - def to_json(self) -> Mapping[str, Any]: - base = {"description": self.description} - if self.is_entity: - base["is_entity"] = True - if self.is_event_timestamp: - base["is_event_timestamp"] = True - if self.is_time_dimension: - base["smallest_time_grain"] = self.smallest_time_grain - if self.derived_dimensions: - base["derived_dimensions"] = list(self.derived_dimensions) - if self.metadata: - base.update(self.metadata) - return base - - def __hash__(self) -> int: - return hash( - ( - self.description, - self.is_entity, - self.is_event_timestamp, - self.is_time_dimension, - self.smallest_time_grain, - self.derived_dimensions, - ), - ) - - -@frozen(kw_only=True, slots=True) -class Measure: - expr: Callable[[ir.Table], ir.Value] | Deferred - description: str | None = None - requires_unnest: tuple[str, ...] = () # Internal: Arrays that must be unnested - original_expr: Any = field(default=None, eq=False, hash=False) - metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) - - def __call__(self, table: ir.Table) -> ir.Value: - return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) - - @property - def locality(self) -> str | None: - """Derive locality from requires_unnest (most nested level).""" - return self.requires_unnest[-1] if self.requires_unnest else None +from ._values import ( # noqa: E402 + Dimension, + Measure, + _DimPrefixProxy, + _DimensionTableProxy, + _is_deferred, +) - def to_json(self) -> Mapping[str, Any]: - base = {"description": self.description} - if self.locality: - base["locality"] = self.locality - if self.requires_unnest: - base["requires_unnest"] = list(self.requires_unnest) - if self.metadata: - base.update(self.metadata) - return base - def __hash__(self) -> int: - return hash((self.description, self.requires_unnest)) class SemanticTableOp(Relation): @@ -1143,7 +1004,7 @@ def values(self) -> FrozenOrderedDict[str, Any]: # ``compile_grouped_with_all`` so calc measures with inline aggregations # (e.g. ``AllOf(AggregationExpr)``) round-trip through type inference. if calc_measures: - from .compile_all import _get_ibis_module, infer_calc_dtype + from ..compile_all import _get_ibis_module, infer_calc_dtype measure_schema = { name: base_values[name].dtype for name in measures if name in base_values @@ -1203,7 +1064,7 @@ def get_calculated_measures(self) -> Mapping[str, Any]: return self.calc_measures def get_graph(self) -> dict[str, dict[str, Any]]: - from .graph_utils import build_dependency_graph + from ..graph_utils import build_dependency_graph return build_dependency_graph( self.get_dimensions(), @@ -1263,7 +1124,7 @@ def schema(self) -> Schema: return self.source.schema def to_untagged(self): - from .convert import _Resolver + from ..convert import _Resolver all_roots = _find_all_root_models(self.source) base_tbl = _to_untagged(self.source) @@ -2440,7 +2301,7 @@ def _to_untagged_with_preagg( # Apply collected filters to the full joined table so that # dimension bridges only include rows surviving the filter. 
if filters: - from .convert import _Resolver + from ..convert import _Resolver for pred in filters: pred_fn = _unwrap(pred) @@ -2864,7 +2725,7 @@ def strip_deferred(node): # Apply filters if filters: - from .convert import _Resolver + from ..convert import _Resolver for pred in filters: pred_fn = _unwrap(pred) @@ -2908,7 +2769,7 @@ def strip_deferred(node): # Handle calculated measures if plan.calc_specs: - from .compile_all import compile_calc_measures + from ..compile_all import compile_calc_measures result = compile_calc_measures(result, plan.calc_specs) @@ -2985,7 +2846,7 @@ def _join_preagg_with_dim_bridge( ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. """ - from .compile_all import _join_tables + from ..compile_all import _join_tables reagg_map = dict(reagg_ops) # Include decomposed auxiliary columns in measure names @@ -3069,7 +2930,7 @@ def _build_minimal_dim_bridge( ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. """ - from .compile_all import _join_tables + from ..compile_all import _join_tables reagg_map = dict(reagg_ops) aux_cols = frozenset(c for _, (sc, cc) in decomposed_means for c in (sc, cc)) @@ -3133,7 +2994,7 @@ def _bridge_one_preagg(pt): @staticmethod def _apply_calc_specs(result, plan, tbl): """Apply calculated measure specs (ratios, percent-of-total, etc.).""" - from .compile_all import _collect_all_refs, _compile_formula + from ..compile_all import _collect_all_refs, _compile_formula needed_totals: set[str] = set() for ast in plan.calc_specs.values(): @@ -3429,7 +3290,7 @@ def query( time_range: dict[str, str] | None = None, having: list | None = None, ): - from .query import query as build_query + from ..query import query as build_query return build_query( semantic_table=self, @@ -3479,12 +3340,12 @@ def with_measures(self, **meas) -> SemanticTable: ) def group_by(self, *keys: str) -> SemanticGroupBy: - from .expr import SemanticGroupBy + from ..expr import SemanticGroupBy return SemanticGroupBy(source=self, keys=keys) def filter(self, predicate: Callable) -> SemanticFilter: - from .expr import SemanticFilter + from ..expr import SemanticFilter return SemanticFilter(source=self, predicate=predicate) @@ -3495,7 +3356,7 @@ def join_one( how: str = "left", ): """Join with one-to-one relationship semantics (left outer join).""" - from .expr import SemanticJoin + from ..expr import SemanticJoin return SemanticJoin( left=self, @@ -3512,7 +3373,7 @@ def join_many( how: str = "left", ): """Join with one-to-many relationship semantics.""" - from .expr import SemanticJoin + from ..expr import SemanticJoin return SemanticJoin( left=self, @@ -3524,7 +3385,7 @@ def join_many( def join_cross(self, other: SemanticTable): """Cross join (Cartesian product) with another semantic model.""" - from .expr import SemanticJoin + from ..expr import SemanticJoin return SemanticJoin( left=self, @@ -3984,7 +3845,7 @@ def to_untagged(self, parent_requirements: dict[str, set[str]] | None = None): Returns: Ibis join expression (potentially simplified). """ - from .convert import _Resolver + from ..convert import _Resolver augmented_requirements = self._augment_parent_requirements_for_pruning(parent_requirements) @@ -4076,7 +3937,7 @@ def _rebind_join_backends(left_tbl, right_tbl): returning the inputs unchanged so ibis executes the join natively. 
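+
+    A sketch of the fallback contract (table names are assumed, not from
+    the surrounding code)::
+
+        left2, right2 = _rebind_join_backends(left_tbl, right_tbl)
+        # plain-ibis inputs come back as-is; xorq-backed inputs are rebound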
""" try: - from ._xorq import relations as xorq_rel, walk_nodes + from .._xorq import relations as xorq_rel, walk_nodes except ImportError: return left_tbl, right_tbl @@ -4262,7 +4123,7 @@ def _get_weight_expr( all_roots: list, is_string: bool, ) -> Any: - from ._xorq import api as xo + from .._xorq import api as xo if not by_measure: return xo._.count() @@ -4281,7 +4142,7 @@ def _build_string_index_fragment( type_str: str, weight_expr: Any, ) -> Any: - from ._xorq import api as xo + from .._xorq import api as xo return ( base_tbl.group_by(field_expr.name("value")) @@ -4304,7 +4165,7 @@ def _build_numeric_index_fragment( type_str: str, weight_expr: Any, ) -> Any: - from ._xorq import api as xo + from .._xorq import api as xo return ( base_tbl.select(field_expr.name("value")) @@ -4402,7 +4263,7 @@ def __repr__(self) -> str: @property def values(self) -> FrozenOrderedDict[str, Any]: - from ._xorq import api as xo + from .._xorq import api as xo return FrozenOrderedDict( { @@ -4450,7 +4311,7 @@ def to_untagged(self): ) if not fields_to_index: - from ._xorq import api as xo + from .._xorq import api as xo return xo.memtable( { @@ -4501,17 +4362,17 @@ def build_fragment(field_name: str) -> Any: return reduce(lambda acc, frag: acc.union(frag), fragments[1:], fragments[0]) def filter(self, predicate: Callable) -> SemanticFilter: - from .expr import SemanticFilter + from ..expr import SemanticFilter return SemanticFilter(source=self, predicate=predicate) def order_by(self, *keys: str | ir.Value | Callable) -> SemanticOrderBy: - from .expr import SemanticOrderBy + from ..expr import SemanticOrderBy return SemanticOrderBy(source=self, keys=keys) def limit(self, n: int, offset: int = 0) -> SemanticLimit: - from .expr import SemanticLimit + from ..expr import SemanticLimit return SemanticLimit(source=self, n=n, offset=offset) @@ -4781,7 +4642,7 @@ def _build_column_rename_map( # Build column index using graph_utils (returns Result) from returns.result import Failure - from .graph_utils import build_column_index_from_roots, extract_column_from_dimension + from ..graph_utils import build_column_index_from_roots, extract_column_from_dimension column_index_result = build_column_index_from_roots(all_roots) if isinstance(column_index_result, Failure): @@ -4931,7 +4792,7 @@ def _merge_fields_with_prefixing( if all_roots: sample_fields = field_accessor(all_roots[0]) if sample_fields: - from .measure_scope import AllOf, BinOp, MeasureRef, MethodCall + from ..measure_scope import AllOf, BinOp, MeasureRef, MethodCall first_val = next(iter(sample_fields.values()), None) is_calc_measures = isinstance( @@ -4983,348 +4844,17 @@ def _merge_fields_with_prefixing( return FrozenDict(merged_fields) - -# ============================================================================== -# Column Tracking for Projection Pushdown -# ============================================================================== - - -@frozen -class ColumnTracker: - """Immutable tracker for column references during expression evaluation. - - Uses frozenset for tracked columns. New columns are added by creating - new tracker instances with updated sets. 
- """ - - columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - - def with_column(self, col_name: str) -> ColumnTracker: - """Return new tracker with additional column.""" - return ColumnTracker(columns=self.columns | {col_name}) - - def merge(self, other: ColumnTracker) -> ColumnTracker: - """Return new tracker with merged columns.""" - return ColumnTracker(columns=self.columns | other.columns) - - -@frozen -class ColumnExtractionResult: - """Result of column extraction with error handling. - - Separates successful extraction from error cases. - """ - - columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - extraction_failed: bool = False - error_type: type[Exception] | None = None - - @classmethod - def success(cls, columns: set[str] | frozenset[str]) -> ColumnExtractionResult: - """Create successful result.""" - return cls(columns=frozenset(columns), extraction_failed=False) - - @classmethod - def failure(cls, error: Exception) -> ColumnExtractionResult: - """Create failure result with error information.""" - return cls( - columns=frozenset(), - extraction_failed=True, - error_type=type(error), - ) - - def is_success(self) -> bool: - """Check if extraction succeeded.""" - return not self.extraction_failed - - -@frozen -class JoinColumnExtractionResult: - """Result of join column extraction for both tables.""" - - left_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - right_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - extraction_failed: bool = False - error_type: type[Exception] | None = None - - @classmethod - def success( - cls, - left: set[str] | frozenset[str], - right: set[str] | frozenset[str], - ) -> JoinColumnExtractionResult: - """Create successful result.""" - return cls( - left_columns=frozenset(left), - right_columns=frozenset(right), - extraction_failed=False, - ) - - @classmethod - def failure(cls, error: Exception) -> JoinColumnExtractionResult: - """Create failure result with error information.""" - return cls( - left_columns=frozenset(), - right_columns=frozenset(), - extraction_failed=True, - error_type=type(error), - ) - - def is_success(self) -> bool: - """Check if extraction succeeded.""" - return not self.extraction_failed - - -def _make_tracking_proxy( - table: ir.Table, - on_access: Callable[[str], None], -) -> Any: - """Create tracking proxy with custom access handler. - - Composable factory that enables different tracking strategies - via the on_access callback. - """ - - class _TrackingProxy: - """Proxy that tracks attribute and item access.""" - - def __init__(self, inner_table: ir.Table, access_handler: Callable[[str], None]): - object.__setattr__(self, "_table", inner_table) - object.__setattr__(self, "_on_access", access_handler) - - def __getattr__(self, name: str): - if name.startswith("_"): - return getattr(self._table, name) - self._on_access(name) - return getattr(self._table, name) - - def __getitem__(self, name: str): - self._on_access(name) - return self._table[name] - - return _TrackingProxy(table, on_access) - - -def _extract_columns_from_callable( - fn: Any, - table: ir.Table, -) -> ColumnExtractionResult: - """Extract column names referenced by a callable. - - Uses immutable tracking and returns structured result. 
- """ - if not callable(fn): - return ColumnExtractionResult.success(frozenset()) - - tracker_ref = [ColumnTracker()] - - def on_column_access(col_name: str) -> None: - tracker_ref[0] = tracker_ref[0].with_column(col_name) - - try: - tracking_proxy = _make_tracking_proxy(table, on_column_access) - fn(tracking_proxy) - return ColumnExtractionResult.success(tracker_ref[0].columns) - - except Exception as e: - return ColumnExtractionResult.failure(e) - - -def _extract_join_key_columns( - on: Callable[[Any, Any], ir.BooleanValue], - left_table: ir.Table, - right_table: ir.Table, -) -> JoinColumnExtractionResult: - left_tracker_ref = [ColumnTracker()] - right_tracker_ref = [ColumnTracker()] - - def on_left_access(col_name: str) -> None: - left_tracker_ref[0] = left_tracker_ref[0].with_column(col_name) - - def on_right_access(col_name: str) -> None: - right_tracker_ref[0] = right_tracker_ref[0].with_column(col_name) - - try: - left_proxy = _make_tracking_proxy(left_table, on_left_access) - right_proxy = _make_tracking_proxy(right_table, on_right_access) - on(left_proxy, right_proxy) - - return JoinColumnExtractionResult.success( - left_tracker_ref[0].columns, - right_tracker_ref[0].columns, - ) - - except Exception as e: - return JoinColumnExtractionResult.failure(e) - - -# ============================================================================== -# Table Column Requirements -# ============================================================================== - - -@frozen -class TableColumnRequirements: - """Immutable representation of column requirements per table. - - Maps table names to sets of required column names. - """ - - requirements: FrozenDict[str, frozenset[str]] = field( - factory=lambda: FrozenDict({}), - converter=lambda d: FrozenDict( - {k: frozenset(v) if not isinstance(v, frozenset) else v for k, v in d.items()}, - ), - ) - - def with_column(self, table_name: str, col_name: str) -> TableColumnRequirements: - """Return new requirements with additional column for table.""" - current_cols = self.requirements.get(table_name, frozenset()) - updated_cols = current_cols | {col_name} - - return TableColumnRequirements( - requirements=dict(self.requirements) | {table_name: updated_cols}, - ) - - def with_columns( - self, - table_name: str, - col_names: Iterable[str], - ) -> TableColumnRequirements: - """Return new requirements with multiple columns for table.""" - current_cols = self.requirements.get(table_name, frozenset()) - updated_cols = current_cols | frozenset(col_names) - - return TableColumnRequirements( - requirements=dict(self.requirements) | {table_name: updated_cols}, - ) - - def merge(self, other: TableColumnRequirements) -> TableColumnRequirements: - """Merge requirements from another instance.""" - merged_dict = dict(self.requirements) - - for table, cols in other.requirements.items(): - if table in merged_dict: - merged_dict[table] = merged_dict[table] | cols - else: - merged_dict[table] = cols - - return TableColumnRequirements(requirements=merged_dict) - - def to_dict(self) -> dict[str, set[str]]: - """Convert to mutable dict for API compatibility.""" - return {table: set(cols) for table, cols in self.requirements.items()} - - -def _parse_prefixed_field(field_name: str) -> tuple[str | None, str]: - """Parse potentially prefixed field name. - - Args: - field_name: Field name, possibly prefixed (e.g., "table.column") - - Returns: - Tuple of (table_name or None, column_name) - """ - if "." 
in field_name: - table, col = field_name.split(".", 1) - return (table, col) - return (None, field_name) - - -def _extract_requirements_from_keys( - keys: Iterable[str], - merged_dimensions: Mapping[str, Any], - all_roots: Sequence[Any], - table: ir.Table, -) -> TableColumnRequirements: - """Extract column requirements from group-by keys using graph traversal.""" - requirements = TableColumnRequirements() - - for key in keys: - table_name, col_name = _parse_prefixed_field(key) - - if table_name: - # Prefixed: we know the table - requirements = requirements.with_column(table_name, col_name) - else: - # Unprefixed: resolve dimension or use conservative fallback - if key in merged_dimensions: - dim_fn = merged_dimensions[key] - - try: - # Evaluate the dimension to get an Ibis expression - dim_expr = dim_fn(table) - - # Walk the expression graph to find all Field nodes (column references) - field_names = {node.name for node in walk_nodes(ibis_ops.Field, dim_expr)} - - # Filter to only actual columns in the table schema - actual_cols = {col for col in field_names if col in table.columns} - - if actual_cols: - for root in all_roots: - if root.name: - requirements = requirements.with_columns(root.name, actual_cols) - else: - # Fallback: assume key name is column name - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, key) - except Exception: - logger.debug( - "dimension graph traversal failed for %r; " - "treating key name as column name", - key, - exc_info=True, - ) - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, key) - else: - # Raw column - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, key) - - return requirements - - -def _extract_requirements_from_measures( - aggs: Mapping[str, Callable], - all_roots: Sequence[Any], - table: ir.Table, -) -> TableColumnRequirements: - """Extract column requirements from measure aggregations using graph traversal.""" - requirements = TableColumnRequirements() - - for measure_name, measure_fn in aggs.items(): - fn = _unwrap(measure_fn) - - try: - # Evaluate the measure to get an Ibis expression - measure_expr = fn(table) - - # Walk the expression graph to find all Field nodes (column references) - field_names = {node.name for node in walk_nodes(ibis_ops.Field, measure_expr)} - - # Filter to only actual columns in the table schema - actual_cols = {col for col in field_names if col in table.columns} - - if actual_cols: - for root in all_roots: - if root.name: - requirements = requirements.with_columns(root.name, actual_cols) - except Exception: - logger.debug( - "measure graph traversal failed for %r; " - "falling back to name-based column inference", - measure_name, - exc_info=True, - ) - # Conservative fallback: if measure name looks like a column, include it - if measure_name.isidentifier(): - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, measure_name) - - return requirements +# Column-extraction classes & helpers were moved to ._column_extraction. +# Re-imported here so internal callers in this module keep working. 
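+#
+# Hypothetical internal call site, unchanged by the move:
+#
+#     reqs = _extract_requirements_from_keys(keys, merged_dims, roots, tbl)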
+from ._column_extraction import ( # noqa: E402 + ColumnExtractionResult, + ColumnTracker, + JoinColumnExtractionResult, + TableColumnRequirements, + _extract_columns_from_callable, + _extract_join_key_columns, + _extract_requirements_from_keys, + _extract_requirements_from_measures, + _make_tracking_proxy, + _parse_prefixed_field, +) diff --git a/src/boring_semantic_layer/ops/_values.py b/src/boring_semantic_layer/ops/_values.py new file mode 100644 index 0000000..b06735a --- /dev/null +++ b/src/boring_semantic_layer/ops/_values.py @@ -0,0 +1,166 @@ +"""Public value objects: ``Dimension`` and ``Measure``. + +Plus the prefix proxies that let dimension lambdas use model-prefix +navigation (``lambda t: t.flights.carrier``) on joined tables. +""" + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from typing import Any + +from attrs import field, frozen +from ibis.common.deferred import Deferred +from ibis.expr import types as ir + + +def _is_deferred(expr) -> bool: + # Duck-type check: works for both ibis and xorq Deferred objects + return hasattr(expr, "_resolver") and hasattr(expr, "resolve") + + +class _DimPrefixProxy: + """Resolves ``proxy.column`` to ``dims["prefix.column"](table)``.""" + + __slots__ = ("_tbl", "_dims", "_prefix") + + def __init__(self, tbl, dims: dict, prefix: str): + object.__setattr__(self, "_tbl", tbl) + object.__setattr__(self, "_dims", dims) + object.__setattr__(self, "_prefix", prefix) + + def __getattr__(self, name: str): + full_name = f"{self._prefix}.{name}" + if full_name in self._dims: + return self._dims[full_name](self._tbl) + raise AttributeError( + f"No dimension '{full_name}' found. " + f"Available dimensions with prefix '{self._prefix}.': " + f"{[k for k in self._dims if k.startswith(self._prefix + '.')]}" + ) + + +class _DimensionTableProxy: + """Proxy that wraps an ibis table to support model-prefix navigation. + + Allows dimension lambdas like ``lambda t: t.flights.carrier`` to work on + joined tables by resolving ``t.flights.carrier`` through the merged + dimension map (``dims["flights.carrier"](table)``). + """ + + __slots__ = ("_tbl", "_dims") + + def __init__(self, tbl, dims: dict): + object.__setattr__(self, "_tbl", tbl) + object.__setattr__(self, "_dims", dims) + + def __getattr__(self, name: str): + prefix = f"{name}." + if any(k.startswith(prefix) for k in self._dims): + return _DimPrefixProxy(self._tbl, self._dims, name) + return getattr(self._tbl, name) + + def __getitem__(self, name: str): + if name in self._dims: + return self._dims[name](self._tbl) + return self._tbl[name] + + @property + def columns(self): + return self._tbl.columns + + +@frozen(kw_only=True, slots=True) +class Dimension: + expr: Callable[[ir.Table], ir.Value] | Deferred + description: str | None = None + is_entity: bool = False + is_time_dimension: bool = False + is_event_timestamp: bool = False + smallest_time_grain: str | None = None + derived_dimensions: tuple[str, ...] 
= () + metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) + + def __call__(self, table: ir.Table, _dims: dict | None = None) -> ir.Value: + try: + return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) + except AttributeError as e: + # Retry with a prefix-aware proxy for joined tables where + # model prefixes are used (e.g., lambda t: t.flights.carrier) + if _dims and not _is_deferred(self.expr) and callable(self.expr): + try: + proxy = _DimensionTableProxy(table, _dims) + return self.expr(proxy) + except AttributeError as proxy_err: + # Preserve explicit prefix-proxy errors (e.g. missing + # "model.field") to avoid silent fallback to unprefixed + # columns, but keep normal missing-column errors on the + # original table so they get the helpful formatter below. + if str(proxy_err).startswith("No dimension '"): + raise + except Exception: + pass + # Provide helpful error for missing columns + if "'Table' object has no attribute" in str( + e + ) or "'Join' object has no attribute" in str(e): + from ._core import _format_column_error + + raise AttributeError(_format_column_error(e, table)) from e + raise + + def to_json(self) -> Mapping[str, Any]: + base = {"description": self.description} + if self.is_entity: + base["is_entity"] = True + if self.is_event_timestamp: + base["is_event_timestamp"] = True + if self.is_time_dimension: + base["smallest_time_grain"] = self.smallest_time_grain + if self.derived_dimensions: + base["derived_dimensions"] = list(self.derived_dimensions) + if self.metadata: + base.update(self.metadata) + return base + + def __hash__(self) -> int: + return hash( + ( + self.description, + self.is_entity, + self.is_event_timestamp, + self.is_time_dimension, + self.smallest_time_grain, + self.derived_dimensions, + ), + ) + + +@frozen(kw_only=True, slots=True) +class Measure: + expr: Callable[[ir.Table], ir.Value] | Deferred + description: str | None = None + requires_unnest: tuple[str, ...] = () # Internal: Arrays that must be unnested + original_expr: Any = field(default=None, eq=False, hash=False) + metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) + + def __call__(self, table: ir.Table) -> ir.Value: + return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) + + @property + def locality(self) -> str | None: + """Derive locality from requires_unnest (most nested level).""" + return self.requires_unnest[-1] if self.requires_unnest else None + + def to_json(self) -> Mapping[str, Any]: + base = {"description": self.description} + if self.locality: + base["locality"] = self.locality + if self.requires_unnest: + base["requires_unnest"] = list(self.requires_unnest) + if self.metadata: + base.update(self.metadata) + return base + + def __hash__(self) -> int: + return hash((self.description, self.requires_unnest)) From 53578480bc7d28729de08d4870f23ae29bdd6f1f Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 4 May 2026 07:40:06 -0400 Subject: [PATCH 2/6] refactor(ops): extract callable wrapper, xorq compat, normalize, root models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull four more self-contained chunks out of _core.py: - _callable.py (60 lines): _CallableWrapper / _ensure_wrapped / _infer_unnest. Tiny module, used by serialization & utils externally. - _xorq_compat.py (127 lines): _ensure_xorq_table, _patch_xorq_sortkey_compat, _rebind_to_backend, _rebind_to_canonical_backend. Pure ibis↔xorq bridge logic. 
- _normalize.py (105 lines): _normalize_to_name and _normalize_join_predicate, the two arg-coercion helpers used by the user-facing API. - _root_models.py (456 lines): root-model traversal (_find_all_root_models et al) and the join-aware field-prefixing / _right column-rename machinery. SemanticTableOp / SemanticJoinOp are imported lazily inside each function to break the import cycle with _core. _core.py is now 4,186 lines (down from 4,621 last commit, 5,330 at start of branch). All extractions are pure moves; _core.py re-imports the names back from each submodule so internal call sites are unchanged. 930 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/ops/_callable.py | 60 ++ src/boring_semantic_layer/ops/_core.py | 714 +----------------- src/boring_semantic_layer/ops/_normalize.py | 105 +++ src/boring_semantic_layer/ops/_root_models.py | 483 ++++++++++++ src/boring_semantic_layer/ops/_xorq_compat.py | 127 ++++ 5 files changed, 795 insertions(+), 694 deletions(-) create mode 100644 src/boring_semantic_layer/ops/_callable.py create mode 100644 src/boring_semantic_layer/ops/_normalize.py create mode 100644 src/boring_semantic_layer/ops/_root_models.py create mode 100644 src/boring_semantic_layer/ops/_xorq_compat.py diff --git a/src/boring_semantic_layer/ops/_callable.py b/src/boring_semantic_layer/ops/_callable.py new file mode 100644 index 0000000..41ace49 --- /dev/null +++ b/src/boring_semantic_layer/ops/_callable.py @@ -0,0 +1,60 @@ +"""Hashable callable wrapper used to put lambdas/Deferred into FrozenDict. + +Both raw callables and user-side ``Deferred`` instances are not hashable +in their bare form, but ibis Op classes store dimension/measure exprs in +``FrozenDict``. ``_CallableWrapper`` gives them an identity-based hash so +they can be persisted in op fields. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from attrs import frozen + + +@frozen +class _CallableWrapper: + """Hashable wrapper for Callable and Deferred. + + Both raw callables (lambda) and user Deferred (_.foo) are not hashable + and cannot be stored in FrozenDict. This wrapper provides hashability + using identity-based hashing. + """ + + _fn: Any + + def __call__(self, *args, **kwargs): + return self._fn(*args, **kwargs) + + def __hash__(self): + # should this be dask.base.tokenize()? + return hash(id(self._fn)) + + @property + def unwrap(self): + return self._fn + + +def _ensure_wrapped(fn: Any) -> _CallableWrapper: + """Wrap Callable or Deferred for hashability.""" + return fn if isinstance(fn, _CallableWrapper) else _CallableWrapper(fn) + + +def _infer_unnest(fn: Callable, table: Any) -> tuple[str, ...]: + """Infer required unnest operations from the table. + + Examples: + to_semantic_table(tbl).with_measures(...) -> () # Session level + to_semantic_table(tbl).unnest("hits").with_measures(...) -> ("hits",) + unnested.unnest("product").with_measures(...) 
-> ("product",) + """ + from ..expr import SemanticUnnest + + if isinstance(table, SemanticUnnest): + op = table.op() + # SemanticUnnestOp always has column attribute + return (op.column,) + + return () diff --git a/src/boring_semantic_layer/ops/_core.py b/src/boring_semantic_layer/ops/_core.py index f886139..39d64e6 100644 --- a/src/boring_semantic_layer/ops/_core.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -91,98 +91,7 @@ def __getattr__(self, name): return getattr(self._table, mapped) -def _normalize_to_name(arg: str | Deferred) -> str: - """Convert a string or simple ``_.name`` Deferred to a plain string name. - - Accepts a plain string (returned as-is) or a Deferred whose resolver is a - simple attribute access on the top-level ``_`` variable (e.g. ``_.origin``). - - Complex expressions like ``_.distance.sum()`` or ``_.a.b`` are rejected - with a ``TypeError``. - """ - if isinstance(arg, str): - return arg - - # Duck-type: works for both ibis and xorq Deferred objects - resolver = getattr(arg, "_resolver", None) - if resolver is None: - raise TypeError( - f"Expected a string name or Deferred expression (_.name), got {type(arg).__name__}" - ) - - obj = getattr(resolver, "obj", None) - - # Try attribute access first (_.name -> Attr resolver with .name) - name_wrapper = getattr(resolver, "name", None) - - # Fall back to getitem access (_["name"] -> Item resolver with .indexer) - if name_wrapper is None: - name_wrapper = getattr(resolver, "indexer", None) - - if name_wrapper is None or obj is None: - raise TypeError( - f"Only simple Deferred expressions like _.name or _['name'] are supported " - f"as positional arguments, got: {arg!r}" - ) - - # Reject chained access like _.a.b (obj would itself have an .obj attr) - if getattr(obj, "obj", None) is not None: - raise TypeError( - f"Only simple Deferred expressions like _.name or _['name'] are supported " - f"as positional arguments, got: {arg!r}" - ) - - # Attr.name / Item.indexer is a Just wrapper; unwrap via .value - raw_name = getattr(name_wrapper, "value", name_wrapper) - if not isinstance(raw_name, str): - raise TypeError(f"Could not extract string name from Deferred expression: {arg!r}") - - return raw_name - - -def _normalize_join_predicate(on): - """Normalize a join predicate to a two-argument callable. 
- - Accepts: - - ``str`` – equi-join on a column present in both sides - - ``Deferred`` (``_.col``) – same, after extracting the name - - ``list[str | Deferred]`` – compound equi-join on multiple columns - - ``callable`` (non-Deferred) – returned as-is (existing lambda API) - - ``None`` – returned as-is (for cross joins) - """ - if on is None: - return on - - if isinstance(on, str): - name = on - return lambda left, right: getattr(left, name) == getattr(right, name) - - if _is_deferred(on): - name = _normalize_to_name(on) - return lambda left, right: getattr(left, name) == getattr(right, name) - - if isinstance(on, (list, tuple)): - names = [_normalize_to_name(item) for item in on] - if len(names) == 1: - name = names[0] - return lambda left, right: getattr(left, name) == getattr(right, name) - - def _compound_predicate(left, right): - from functools import reduce - from operator import and_ - - preds = [getattr(left, n) == getattr(right, n) for n in names] - return reduce(and_, preds) - - return _compound_predicate - - if callable(on): - return on - - raise TypeError( - f"join `on` must be a string, Deferred (_.col), list of strings/Deferred, " - f"or a callable, got {type(on).__name__}" - ) +from ._normalize import _normalize_join_predicate, _normalize_to_name # noqa: E402, F401 if TYPE_CHECKING: @@ -195,117 +104,12 @@ def _compound_predicate(left, right): ) -def _patch_xorq_sortkey_compat(): - """Register a map_ibis handler so ibis SortKey → xorq SortKey. - - ibis 11 uses ``SortKey.expr``, ibis 12 renamed it to ``SortKey.arg``, - while xorq's vendored ibis keeps ``SortKey.expr``. Handle both. - """ - from ibis.expr.operations.sortkeys import SortKey as IbisSortKey - - from .._xorq import SortKey as XorqSortKey, map_ibis - - if IbisSortKey in map_ibis.registry: - return # already patched - - @map_ibis.register(IbisSortKey) - def _map_sort_key(val, kwargs=None): - # ibis 12 uses .arg, ibis 11 uses .expr - sort_expr = getattr(val, "arg", None) or getattr(val, "expr") - return XorqSortKey( - expr=map_ibis(sort_expr, None), - ascending=val.ascending, - nulls_first=val.nulls_first, - ) - - -def _ensure_xorq_table(table): - """Convert plain ibis Table to xorq-vendored ibis if possible. - - This is the single boundary between user-supplied ibis tables and - BSL's internal xorq representation. ``SemanticModel`` calls it once - at construction so internal code paths can assume xorq tables when - the backend is supported, and a plain ibis fallback otherwise. - - Falls back to returning the plain ibis table when the backend is not - supported by xorq (e.g. Databricks). Idempotent: calling it on a - xorq-vendored table is a cheap no-op. - """ - _patch_xorq_sortkey_compat() - if "xorq.vendor.ibis" not in type(table).__module__: - try: - from .._xorq import from_ibis - - return from_ibis(table) - except Exception: - # Backend isn't supported by xorq's map_ibis registry (e.g. - # Databricks). Fall back so plain-ibis paths can still execute. - logger.debug( - "from_ibis failed for %s; using plain ibis table", - type(table).__module__, - exc_info=True, - ) - return table - return table - - -def _rebind_to_backend(expr, target_backend): - """Rebind every ``DatabaseTable`` op in *expr* to *target_backend*. - - Low-level primitive shared with ``serialization.reconstruct``. - No-op on plain ibis expressions or when xorq is unavailable for any - reason; callers must pass a xorq-vendored ``target_backend``. 
- """ - try: - from .._xorq import relations as xorq_rel - except ImportError: - return expr - - def _recreate(op, _kwargs, **overrides): - kwargs = dict(zip(op.__argnames__, op.__args__, strict=False)) - if _kwargs: - kwargs.update(_kwargs) - kwargs.update(overrides) - return op.__recreate__(kwargs) - - def replacer(op, _kwargs): - if isinstance(op, xorq_rel.DatabaseTable) and op.source is not target_backend: - return _recreate(op, _kwargs, source=target_backend) - if _kwargs: - return _recreate(op, _kwargs) - return op - - return expr.op().replace(replacer).to_expr() - - -def _rebind_to_canonical_backend(expr): - """Rebind divergent ``DatabaseTable`` backends in *expr* to share one. - - ``from_ibis()`` creates a distinct ``Backend`` per call, so expressions - built by composing separately-converted tables contain multiple - backends. Picking the first ``DatabaseTable``'s source as canonical - and rebinding the rest eliminates "Multiple backends found" errors - at execution time. - - No-op on plain ibis expressions (not xorq-vendored). - """ - try: - from .._xorq import relations as xorq_rel, walk_nodes - except ImportError: - return expr - - try: - db_tables = list(walk_nodes((xorq_rel.DatabaseTable,), expr)) - except Exception: - # walk_nodes can't traverse plain ibis trees; treat as no-op. - logger.debug("walk_nodes failed on plain ibis expr", exc_info=True) - return expr - - canonical = db_tables[0].source if db_tables else None - if canonical is None: - return expr - - return _rebind_to_backend(expr, canonical) +from ._xorq_compat import ( # noqa: E402 + _ensure_xorq_table, + _patch_xorq_sortkey_compat, + _rebind_to_backend, + _rebind_to_canonical_backend, +) def _to_untagged(source: Any) -> ir.Table: @@ -601,50 +405,7 @@ def _classify_dependencies( } -@frozen -class _CallableWrapper: - """Hashable wrapper for Callable and Deferred. - - Both raw callables (lambda) and user Deferred (_.foo) are not hashable - and cannot be stored in FrozenDict. This wrapper provides hashability - using identity-based hashing. - """ - - _fn: Any - - def __call__(self, *args, **kwargs): - return self._fn(*args, **kwargs) - - def __hash__(self): - # should this be dask.base.tokenize()? - return hash(id(self._fn)) - - @property - def unwrap(self): - return self._fn - - -def _ensure_wrapped(fn: Any) -> _CallableWrapper: - """Wrap Callable or Deferred for hashability.""" - return fn if isinstance(fn, _CallableWrapper) else _CallableWrapper(fn) - - -def _infer_unnest(fn: Callable, table: Any) -> tuple[str, ...]: - """Infer required unnest operations from the table. - - Examples: - to_semantic_table(tbl).with_measures(...) -> () # Session level - to_semantic_table(tbl).unnest("hits").with_measures(...) -> ("hits",) - unnested.unnest("product").with_measures(...) 
-> ("product",) - """ - from ..expr import SemanticUnnest - - if isinstance(table, SemanticUnnest): - op = table.op() - # SemanticUnnestOp always has column attribute - return (op.column,) - - return () +from ._callable import _CallableWrapper, _ensure_wrapped, _infer_unnest # noqa: E402 def _extract_measure_metadata( @@ -4396,453 +4157,18 @@ def pipe(self, func, *args, **kwargs): return func(self, *args, **kwargs) -def _find_root_model(node: Any) -> SemanticTableOp | None: - """Find root SemanticTableOp in the operation tree.""" - cur = node - while cur is not None: - if isinstance(cur, SemanticTableOp): - return cur - parent = getattr(cur, "source", None) - cur = parent - return None - - -def _find_all_root_models(node: Any) -> tuple[SemanticTableOp, ...]: - """Find all root SemanticTableOps in the operation tree (handles joins with multiple roots).""" - if isinstance(node, SemanticTableOp): - return [node] - - roots = [] - - if hasattr(node, "left") and hasattr(node, "right"): - roots.extend(_find_all_root_models(node.left)) - roots.extend(_find_all_root_models(node.right)) - elif hasattr(node, "source") and node.source is not None: - roots.extend(_find_all_root_models(node.source)) - - return roots - - -def _dimension_only_source_table( - keys: tuple[str, ...], - all_roots: Sequence[SemanticTableOp], - filters: tuple, -) -> tuple[SemanticTableOp, list[str], tuple] | None: - """Check if a dimension-only query can be routed to a single source table. - - When all requested dimension keys share a single table prefix and that - prefix maps to a root model whose dimensions cover every key, we can - bypass the join and query the dimension table directly. This ensures - dimension members with no matching fact rows are still returned. - - *filters* are the ``_CallableWrapper`` predicates collected between the - aggregate and the underlying join. Filters whose column references all - belong to the target table are forwarded; if any filter references columns - outside the target table the shortcut is disabled. - - Returns ``(root_op, unprefixed_keys, applicable_filters)`` or ``None``. - """ - if not keys: - return None - - prefixes: set[str] = set() - unprefixed: list[str] = [] - for key in keys: - if "." not in key: - return None # Non-prefixed key — can't determine source - prefix, name = key.split(".", 1) - prefixes.add(prefix) - unprefixed.append(name) - - if len(prefixes) != 1: - return None # Keys span multiple tables - - target_prefix = next(iter(prefixes)) - - for root in all_roots: - if root.name == target_prefix: - root_dims = root.get_dimensions() - if all(k in root_dims for k in unprefixed): - # Validate that every filter only touches columns present - # on the target dimension table. If any filter references - # columns from other tables we cannot use the shortcut. - if filters: - tbl = _to_untagged(root) - tbl_cols = frozenset(tbl.columns) | frozenset(root_dims) - for flt in filters: - fn = _unwrap(flt) if hasattr(flt, "unwrap") else flt - extraction = _extract_columns_from_callable(fn, tbl) - if extraction.extraction_failed: - return None # Can't determine — bail out - if not extraction.columns <= tbl_cols: - return None # References columns outside target - return root, unprefixed, filters - - return None - - -def _build_join_depth_map(node: Any) -> dict[str, int]: - """Map each leaf table name to its actual ibis rname depth. - - ``SemanticJoinOp.to_untagged`` calls ``_join_depth`` to determine the - rname suffix for each join level. 
``_join_depth`` counts the number - of ``SemanticJoinOp`` ancestors on the *left* spine. The right child - at depth *d* gets ``rname = _rname_for_depth(d)``. - - For nested subtrees on the right side of a join, ibis applies the - inner subtree's rname independently. So ``aircraft_models`` at inner - depth 1 gets ``_right``, not ``_right3`` even if the outer depth is 3. - - This function mirrors ``_join_depth`` logic: walk down the left spine, - recording the right child's depth at each level. If the right child is - itself a join tree, recurse to get inner depths for its leaves. - """ - depth_map: dict[str, int] = {} - - def _record_leaf(n, depth: int): - """Record a leaf table at the given depth.""" - if isinstance(n, SemanticTableOp): - name = n.name - if name and name not in depth_map: - depth_map[name] = depth - - def _walk_join_spine(n): - """Walk the left spine of a join tree, recording depths.""" - if not isinstance(n, SemanticJoinOp): - # Leftmost leaf: depth 0 (root, never renamed) - _record_leaf(n, 0) - return - - depth = SemanticJoinOp._join_depth(n) - # The right child is joined at this depth - right = n.right - if isinstance(right, SemanticJoinOp): - # Right is a subtree — its leaves get inner depths - inner_map = _build_join_depth_map(right) - for tname, idepth in inner_map.items(): - if tname not in depth_map: - if idepth == 0: - # Leftmost leaf of subtree sits at the outer depth - # (it receives the outer rname suffix if conflicting) - depth_map[tname] = depth - else: - # Inner leaves keep their inner depth (inner rname) - depth_map[tname] = idepth - else: - _record_leaf(right, depth) - - # Recurse down the left spine - _walk_join_spine(n.left) - - _walk_join_spine(node) - return depth_map - - -def _update_measure_refs_in_calc(expr, prefix_map: dict[str, str]): - """ - Recursively update MeasureRef names in a calculated measure expression. - - Args: - expr: A MeasureExpr (MeasureRef, AllOf, BinOp, MethodCall, or literal) - prefix_map: Mapping from old name to new prefixed name - - Returns: - Updated expression with prefixed MeasureRef names - """ - if isinstance(expr, MeasureRef): - # Update the measure reference name if it's in the map - new_name = prefix_map.get(expr.name, expr.name) - return MeasureRef(new_name) - elif isinstance(expr, AllOf): - # Update the inner MeasureRef - updated_ref = _update_measure_refs_in_calc(expr.ref, prefix_map) - return AllOf(updated_ref) - elif isinstance(expr, MethodCall): - updated_receiver = _update_measure_refs_in_calc(expr.receiver, prefix_map) - return MethodCall( - receiver=updated_receiver, - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, - ) - elif isinstance(expr, BinOp): - # Recursively update left and right - updated_left = _update_measure_refs_in_calc(expr.left, prefix_map) - updated_right = _update_measure_refs_in_calc(expr.right, prefix_map) - return BinOp(op=expr.op, left=updated_left, right=updated_right) - else: - # Literal number or other - return as-is - return expr - - -def _extract_join_key_column_names(source: Relation) -> set[str]: - """ - Extract column names that ibis will merge (coalesce) during joins. - - Ibis only merges join-key columns when **both** sides of an equi-join share - the **same** column name (e.g., ``l.code == r.code``). When names differ - (e.g., ``l.carrier == r.code``), the right column gets a ``_right`` suffix - instead. We return only the intersection of left/right key names so that - ``_check_and_add_rename`` correctly detects columns that need renaming. 
- - Args: - source: The relation to search for join operations - - Returns: - Set of column names that ibis merges (same-name equi-join keys) - """ - join_keys: set[str] = set() - - def find_joins(node): - """Recursively find join operations and extract merged key columns.""" - if isinstance(node, SemanticJoinOp) and node.on: - try: - left_expr = node.left.to_expr() if hasattr(node.left, "to_expr") else node.left - right_expr = node.right.to_expr() if hasattr(node.right, "to_expr") else node.right - result = _extract_join_key_columns(node.on, left_expr, right_expr) - if result.is_success(): - # ibis merges only same-name equi-join columns - join_keys.update(result.left_columns & result.right_columns) - except (AttributeError, TypeError): - pass - - if hasattr(node, "left") and isinstance(node.left, Relation): - find_joins(node.left) - if hasattr(node, "right") and isinstance(node.right, Relation): - find_joins(node.right) - if hasattr(node, "source") and isinstance(node.source, Relation): - find_joins(node.source) - - find_joins(source) - return join_keys - - -def _build_column_rename_map( - all_roots: Sequence[SemanticTable], - field_accessor: callable, - source: Relation | None = None, -) -> dict[str, str]: - """ - Build a mapping of dimension names to their renamed column names in joined tables. - - When Ibis joins tables with duplicate column names, it renames columns from later - tables with '_right' suffix. However, columns used as join keys are merged and - NOT renamed, so we exclude them from the rename map. - - Uses graph_utils for generic traversal and the returns library for safe handling. - - Args: - all_roots: List of root semantic tables in join order - field_accessor: Function to get fields (dimensions) from a root - source: Optional source relation to extract join keys from - - Returns: - Dict mapping dimension names like 'airports.city' to renamed columns like 'city_right' - """ - # Build column index using graph_utils (returns Result) - from returns.result import Failure - - from ..graph_utils import build_column_index_from_roots, extract_column_from_dimension - - column_index_result = build_column_index_from_roots(all_roots) - if isinstance(column_index_result, Failure): - # If we can't build the index, return empty map (dimensions will use fallback behavior) - return {} - - column_index = column_index_result.value_or({}) - - # Extract join key columns to exclude from renaming - join_keys = _extract_join_key_column_names(source) if source else set() - - # Build a map from table name → actual ibis join depth by walking the - # join tree. The flat index in all_roots does NOT equal ibis join depth - # for nested joins (e.g. aircraft → aircraft_models inside a flights - # join tree), so we must compute it from the tree structure. 
- join_depth_map: dict[str, int] = {} - if source is not None: - join_depth_map = _build_join_depth_map(source) - - # Process dimensions and determine which need renamed columns - rename_map = {} - - for idx, root in enumerate(all_roots): - if not root.name: - continue - - fields_dict = field_accessor(root) - if not fields_dict: - continue - - root_tbl = root.to_untagged() - # Use the actual join depth if available, otherwise fall back to table_idx - effective_depth = join_depth_map.get(root.name, idx) - - for field_name, field_value in fields_dict.items(): - # Extract column name using graph_utils (returns Maybe) - column_maybe = extract_column_from_dimension(field_value, root_tbl) - - # Use Maybe pattern from returns library - column_maybe.bind_optional( - lambda base_column: _check_and_add_rename( # noqa: B023 - rename_map=rename_map, - base_column=base_column, - prefixed_name=f"{root.name}.{field_name}", # noqa: B023 - table_idx=idx, # noqa: B023 - column_index=column_index, - join_keys=join_keys, - join_depth=effective_depth, # noqa: B023 - ) - ) - - return rename_map - - -def _check_and_add_rename( - rename_map: dict[str, str], - base_column: str, - prefixed_name: str, - table_idx: int, - column_index: dict[str, list[int]], - join_keys: set[str], - join_depth: int | None = None, -) -> None: - """ - Check if a column needs renaming and add to rename map if so. - - ``table_idx`` is the flat index in ``all_roots`` used to detect - whether an earlier table has the same column. ``join_depth`` is - the actual ibis join depth (from ``_build_join_depth_map``) used - to compute the ``_right`` / ``_right2`` / … suffix. - - Args: - rename_map: Map to update with renames - base_column: The base column name - prefixed_name: The prefixed dimension name (e.g., 'airports.city') - table_idx: Flat index in all_roots (for conflict detection) - column_index: Index of column occurrences - join_keys: Set of column names used as join keys (these don't get renamed) - join_depth: Actual ibis join depth for suffix computation (defaults to table_idx) - """ - # Skip columns that are join keys - they get merged, not renamed - if base_column in join_keys: - return - - depth = join_depth if join_depth is not None else table_idx - - if base_column in column_index: - tables_with_column = column_index[base_column] - # Check if any table before this one (in flat order) has the same column - earlier_tables = [t for t in tables_with_column if t < table_idx] - if earlier_tables: - suffix = "_right" if depth <= 1 else f"_right{depth}" - rename_map[prefixed_name] = f"{base_column}{suffix}" - - -def _wrap_dimension_for_renamed_column(dimension: Dimension, renamed_column: str) -> Dimension: - """ - Wrap a dimension to access a renamed column in a joined table. 
- - Args: - dimension: The original dimension - renamed_column: The renamed column name (e.g., 'city_right') - - Returns: - A new Dimension that accesses the renamed column - """ - - # Create a new callable that accesses the renamed column - def renamed_accessor(table: ir.Table) -> ir.Value: - return table[renamed_column] - - # Return a new Dimension with the wrapped callable but same metadata - return Dimension( - expr=renamed_accessor, - description=dimension.description, - is_entity=dimension.is_entity, - is_time_dimension=dimension.is_time_dimension, - is_event_timestamp=dimension.is_event_timestamp, - smallest_time_grain=dimension.smallest_time_grain, - derived_dimensions=dimension.derived_dimensions, - ) - - -def _merge_fields_with_prefixing( - all_roots: Sequence[SemanticTable], - field_accessor: callable, - source: Relation | None = None, -) -> FrozenDict[str, Any]: - """ - Generic function to merge any type of fields (dimensions, measures) with prefixing. - - Args: - all_roots: List of SemanticTable root models - field_accessor: Function that takes a root and returns the fields dict (e.g. lambda r: r.dimensions) - source: Optional source relation to extract join keys from for proper column renaming - - Returns: - FrozenDict mapping field names (always prefixed with table name) to field values - """ - if not all_roots: - return FrozenDict() - - merged_fields = {} - - is_calc_measures = False - is_dimensions = False - if all_roots: - sample_fields = field_accessor(all_roots[0]) - if sample_fields: - from ..measure_scope import AllOf, BinOp, MeasureRef, MethodCall - - first_val = next(iter(sample_fields.values()), None) - is_calc_measures = isinstance( - first_val, - MeasureRef | AllOf | BinOp | MethodCall | int | float, - ) - is_dimensions = isinstance(first_val, Dimension) - - # For dimensions, build a column rename map to handle Ibis join conflicts - column_rename_map = {} - if is_dimensions: - column_rename_map = _build_column_rename_map(all_roots, field_accessor, source) - - for root in all_roots: - root_name = root.name - fields_dict = field_accessor(root) - - if is_calc_measures and root_name: - base_map = ( - {k: f"{root_name}.{k}" for k in root.get_measures()} - if hasattr(root, "get_measures") - else {} - ) - calc_map = ( - {k: f"{root_name}.{k}" for k in root.get_calculated_measures()} - if hasattr(root, "get_calculated_measures") - else {} - ) - prefix_map = {**base_map, **calc_map} - - for field_name, field_value in fields_dict.items(): - if root_name: - # Always use prefixed name with . 
separator - prefixed_name = f"{root_name}.{field_name}" - - # If it's a calculated measure, update internal MeasureRefs - if is_calc_measures: - field_value = _update_measure_refs_in_calc(field_value, prefix_map) - # If it's a dimension that needs column renaming, wrap the callable - elif is_dimensions and prefixed_name in column_rename_map: - field_value = _wrap_dimension_for_renamed_column( - field_value, column_rename_map[prefixed_name] - ) - - merged_fields[prefixed_name] = field_value - else: - # Fallback to original name if no root name - merged_fields[field_name] = field_value - - return FrozenDict(merged_fields) +from ._root_models import ( # noqa: E402 + _build_column_rename_map, + _build_join_depth_map, + _check_and_add_rename, + _dimension_only_source_table, + _extract_join_key_column_names, + _find_all_root_models, + _find_root_model, + _merge_fields_with_prefixing, + _update_measure_refs_in_calc, + _wrap_dimension_for_renamed_column, +) # Column-extraction classes & helpers were moved to ._column_extraction. # Re-imported here so internal callers in this module keep working. diff --git a/src/boring_semantic_layer/ops/_normalize.py b/src/boring_semantic_layer/ops/_normalize.py new file mode 100644 index 0000000..9511bb4 --- /dev/null +++ b/src/boring_semantic_layer/ops/_normalize.py @@ -0,0 +1,105 @@ +"""Argument normalizers for join predicates and dimension/measure names. + +Used by the user-facing ``SemanticTable`` API to accept strings, simple +``Deferred`` expressions (``_.col``), and callable predicates uniformly. +""" + +from __future__ import annotations + +from ibis.common.deferred import Deferred + +from ._values import _is_deferred + + +def _normalize_to_name(arg: str | Deferred) -> str: + """Convert a string or simple ``_.name`` Deferred to a plain string name. + + Accepts a plain string (returned as-is) or a Deferred whose resolver is a + simple attribute access on the top-level ``_`` variable (e.g. ``_.origin``). + + Complex expressions like ``_.distance.sum()`` or ``_.a.b`` are rejected + with a ``TypeError``. + """ + if isinstance(arg, str): + return arg + + # Duck-type: works for both ibis and xorq Deferred objects + resolver = getattr(arg, "_resolver", None) + if resolver is None: + raise TypeError( + f"Expected a string name or Deferred expression (_.name), got {type(arg).__name__}" + ) + + obj = getattr(resolver, "obj", None) + + # Try attribute access first (_.name -> Attr resolver with .name) + name_wrapper = getattr(resolver, "name", None) + + # Fall back to getitem access (_["name"] -> Item resolver with .indexer) + if name_wrapper is None: + name_wrapper = getattr(resolver, "indexer", None) + + if name_wrapper is None or obj is None: + raise TypeError( + f"Only simple Deferred expressions like _.name or _['name'] are supported " + f"as positional arguments, got: {arg!r}" + ) + + # Reject chained access like _.a.b (obj would itself have an .obj attr) + if getattr(obj, "obj", None) is not None: + raise TypeError( + f"Only simple Deferred expressions like _.name or _['name'] are supported " + f"as positional arguments, got: {arg!r}" + ) + + # Attr.name / Item.indexer is a Just wrapper; unwrap via .value + raw_name = getattr(name_wrapper, "value", name_wrapper) + if not isinstance(raw_name, str): + raise TypeError(f"Could not extract string name from Deferred expression: {arg!r}") + + return raw_name + + +def _normalize_join_predicate(on): + """Normalize a join predicate to a two-argument callable. 
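+
+    As an illustrative sketch, the following inputs all normalize to a
+    predicate roughly equivalent to
+    ``lambda left, right: left.code == right.code``::
+
+        _normalize_join_predicate("code")
+        _normalize_join_predicate(_.code)
+        _normalize_join_predicate(["code"])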
+ + Accepts: + - ``str`` – equi-join on a column present in both sides + - ``Deferred`` (``_.col``) – same, after extracting the name + - ``list[str | Deferred]`` – compound equi-join on multiple columns + - ``callable`` (non-Deferred) – returned as-is (existing lambda API) + - ``None`` – returned as-is (for cross joins) + """ + if on is None: + return on + + if isinstance(on, str): + name = on + return lambda left, right: getattr(left, name) == getattr(right, name) + + if _is_deferred(on): + name = _normalize_to_name(on) + return lambda left, right: getattr(left, name) == getattr(right, name) + + if isinstance(on, (list, tuple)): + names = [_normalize_to_name(item) for item in on] + if len(names) == 1: + name = names[0] + return lambda left, right: getattr(left, name) == getattr(right, name) + + def _compound_predicate(left, right): + from functools import reduce + from operator import and_ + + preds = [getattr(left, n) == getattr(right, n) for n in names] + return reduce(and_, preds) + + return _compound_predicate + + if callable(on): + return on + + raise TypeError( + f"join `on` must be a string, Deferred (_.col), list of strings/Deferred, " + f"or a callable, got {type(on).__name__}" + ) diff --git a/src/boring_semantic_layer/ops/_root_models.py b/src/boring_semantic_layer/ops/_root_models.py new file mode 100644 index 0000000..f88644f --- /dev/null +++ b/src/boring_semantic_layer/ops/_root_models.py @@ -0,0 +1,483 @@ +"""Root-model traversal and join-aware field merging. + +These helpers walk a relation tree to find the underlying ``SemanticTableOp`` +roots (one per leaf table) and merge their dimensions/measures with proper +table prefixing. They also handle the ibis quirk of renaming non-key columns +on the right side of a join with ``_right`` / ``_right2`` / … suffixes. + +``SemanticTableOp`` and ``SemanticJoinOp`` are imported lazily inside each +function to avoid a circular module dependency with ``_core.py``. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation + +from .._xorq import FrozenDict +from ._column_extraction import _extract_columns_from_callable, _extract_join_key_columns +from ._values import Dimension + + +def _find_root_model(node: Any): + """Find root SemanticTableOp in the operation tree.""" + from ._core import SemanticTableOp + + cur = node + while cur is not None: + if isinstance(cur, SemanticTableOp): + return cur + parent = getattr(cur, "source", None) + cur = parent + return None + + +def _find_all_root_models(node: Any) -> tuple[Any, ...]: + """Find all root SemanticTableOps in the operation tree (handles joins with multiple roots).""" + from ._core import SemanticTableOp + + if isinstance(node, SemanticTableOp): + return [node] + + roots = [] + + if hasattr(node, "left") and hasattr(node, "right"): + roots.extend(_find_all_root_models(node.left)) + roots.extend(_find_all_root_models(node.right)) + elif hasattr(node, "source") and node.source is not None: + roots.extend(_find_all_root_models(node.source)) + + return roots + + +def _dimension_only_source_table( + keys: tuple[str, ...], + all_roots: Sequence[Any], + filters: tuple, +): + """Check if a dimension-only query can be routed to a single source table. + + When all requested dimension keys share a single table prefix and that + prefix maps to a root model whose dimensions cover every key, we can + bypass the join and query the dimension table directly. 
This ensures + dimension members with no matching fact rows are still returned. + + *filters* are the ``_CallableWrapper`` predicates collected between the + aggregate and the underlying join. Filters whose column references all + belong to the target table are forwarded; if any filter references columns + outside the target table the shortcut is disabled. + + Returns ``(root_op, unprefixed_keys, applicable_filters)`` or ``None``. + """ + from ._core import _to_untagged, _unwrap + + if not keys: + return None + + prefixes: set[str] = set() + unprefixed: list[str] = [] + for key in keys: + if "." not in key: + return None # Non-prefixed key — can't determine source + prefix, name = key.split(".", 1) + prefixes.add(prefix) + unprefixed.append(name) + + if len(prefixes) != 1: + return None # Keys span multiple tables + + target_prefix = next(iter(prefixes)) + + for root in all_roots: + if root.name == target_prefix: + root_dims = root.get_dimensions() + if all(k in root_dims for k in unprefixed): + # Validate that every filter only touches columns present + # on the target dimension table. If any filter references + # columns from other tables we cannot use the shortcut. + if filters: + tbl = _to_untagged(root) + tbl_cols = frozenset(tbl.columns) | frozenset(root_dims) + for flt in filters: + fn = _unwrap(flt) if hasattr(flt, "unwrap") else flt + extraction = _extract_columns_from_callable(fn, tbl) + if extraction.extraction_failed: + return None # Can't determine — bail out + if not extraction.columns <= tbl_cols: + return None # References columns outside target + return root, unprefixed, filters + + return None + + +def _build_join_depth_map(node: Any) -> dict[str, int]: + """Map each leaf table name to its actual ibis rname depth. + + ``SemanticJoinOp.to_untagged`` calls ``_join_depth`` to determine the + rname suffix for each join level. ``_join_depth`` counts the number + of ``SemanticJoinOp`` ancestors on the *left* spine. The right child + at depth *d* gets ``rname = _rname_for_depth(d)``. + + For nested subtrees on the right side of a join, ibis applies the + inner subtree's rname independently. So ``aircraft_models`` at inner + depth 1 gets ``_right``, not ``_right3`` even if the outer depth is 3. + + This function mirrors ``_join_depth`` logic: walk down the left spine, + recording the right child's depth at each level. If the right child is + itself a join tree, recurse to get inner depths for its leaves. 
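+
+    Illustrative shape, assuming a hypothetical left spine three joins deep
+    whose right child is an ``aircraft -> aircraft_models`` subtree::
+
+        {..., "aircraft": 3, "aircraft_models": 1}
+
+    ``aircraft`` (the subtree's leftmost leaf) takes the outer depth 3,
+    while ``aircraft_models`` keeps its inner depth 1 (plain ``_right``).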
+ """ + from ._core import SemanticJoinOp, SemanticTableOp + + depth_map: dict[str, int] = {} + + def _record_leaf(n, depth: int): + """Record a leaf table at the given depth.""" + if isinstance(n, SemanticTableOp): + name = n.name + if name and name not in depth_map: + depth_map[name] = depth + + def _walk_join_spine(n): + """Walk the left spine of a join tree, recording depths.""" + if not isinstance(n, SemanticJoinOp): + # Leftmost leaf: depth 0 (root, never renamed) + _record_leaf(n, 0) + return + + depth = SemanticJoinOp._join_depth(n) + # The right child is joined at this depth + right = n.right + if isinstance(right, SemanticJoinOp): + # Right is a subtree — its leaves get inner depths + inner_map = _build_join_depth_map(right) + for tname, idepth in inner_map.items(): + if tname not in depth_map: + if idepth == 0: + # Leftmost leaf of subtree sits at the outer depth + # (it receives the outer rname suffix if conflicting) + depth_map[tname] = depth + else: + # Inner leaves keep their inner depth (inner rname) + depth_map[tname] = idepth + else: + _record_leaf(right, depth) + + # Recurse down the left spine + _walk_join_spine(n.left) + + _walk_join_spine(node) + return depth_map + + +def _update_measure_refs_in_calc(expr, prefix_map: dict[str, str]): + """ + Recursively update MeasureRef names in a calculated measure expression. + + Args: + expr: A MeasureExpr (MeasureRef, AllOf, BinOp, MethodCall, or literal) + prefix_map: Mapping from old name to new prefixed name + + Returns: + Updated expression with prefixed MeasureRef names + """ + from ..measure_scope import AllOf, BinOp, MeasureRef, MethodCall + + if isinstance(expr, MeasureRef): + # Update the measure reference name if it's in the map + new_name = prefix_map.get(expr.name, expr.name) + return MeasureRef(new_name) + elif isinstance(expr, AllOf): + # Update the inner MeasureRef + updated_ref = _update_measure_refs_in_calc(expr.ref, prefix_map) + return AllOf(updated_ref) + elif isinstance(expr, MethodCall): + updated_receiver = _update_measure_refs_in_calc(expr.receiver, prefix_map) + return MethodCall( + receiver=updated_receiver, + method=expr.method, + args=expr.args, + kwargs=expr.kwargs, + ) + elif isinstance(expr, BinOp): + # Recursively update left and right + updated_left = _update_measure_refs_in_calc(expr.left, prefix_map) + updated_right = _update_measure_refs_in_calc(expr.right, prefix_map) + return BinOp(op=expr.op, left=updated_left, right=updated_right) + else: + # Literal number or other - return as-is + return expr + + +def _extract_join_key_column_names(source: Relation) -> set[str]: + """ + Extract column names that ibis will merge (coalesce) during joins. + + Ibis only merges join-key columns when **both** sides of an equi-join share + the **same** column name (e.g., ``l.code == r.code``). When names differ + (e.g., ``l.carrier == r.code``), the right column gets a ``_right`` suffix + instead. We return only the intersection of left/right key names so that + ``_check_and_add_rename`` correctly detects columns that need renaming. 
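+
+    Sketch of the rule::
+
+        l.code == r.code     ->  contributes {"code"}   (merged by ibis)
+        l.carrier == r.code  ->  contributes nothing    (right side suffixed)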
+ + Args: + source: The relation to search for join operations + + Returns: + Set of column names that ibis merges (same-name equi-join keys) + """ + from ._core import SemanticJoinOp + + join_keys: set[str] = set() + + def find_joins(node): + """Recursively find join operations and extract merged key columns.""" + if isinstance(node, SemanticJoinOp) and node.on: + try: + left_expr = node.left.to_expr() if hasattr(node.left, "to_expr") else node.left + right_expr = node.right.to_expr() if hasattr(node.right, "to_expr") else node.right + result = _extract_join_key_columns(node.on, left_expr, right_expr) + if result.is_success(): + # ibis merges only same-name equi-join columns + join_keys.update(result.left_columns & result.right_columns) + except (AttributeError, TypeError): + pass + + if hasattr(node, "left") and isinstance(node.left, Relation): + find_joins(node.left) + if hasattr(node, "right") and isinstance(node.right, Relation): + find_joins(node.right) + if hasattr(node, "source") and isinstance(node.source, Relation): + find_joins(node.source) + + find_joins(source) + return join_keys + + +def _build_column_rename_map( + all_roots: Sequence[Any], + field_accessor: callable, + source: Relation | None = None, +) -> dict[str, str]: + """ + Build a mapping of dimension names to their renamed column names in joined tables. + + When Ibis joins tables with duplicate column names, it renames columns from later + tables with '_right' suffix. However, columns used as join keys are merged and + NOT renamed, so we exclude them from the rename map. + + Uses graph_utils for generic traversal and the returns library for safe handling. + + Args: + all_roots: List of root semantic tables in join order + field_accessor: Function to get fields (dimensions) from a root + source: Optional source relation to extract join keys from + + Returns: + Dict mapping dimension names like 'airports.city' to renamed columns like 'city_right' + """ + # Build column index using graph_utils (returns Result) + from returns.result import Failure + + from ..graph_utils import build_column_index_from_roots, extract_column_from_dimension + + column_index_result = build_column_index_from_roots(all_roots) + if isinstance(column_index_result, Failure): + # If we can't build the index, return empty map (dimensions will use fallback behavior) + return {} + + column_index = column_index_result.value_or({}) + + # Extract join key columns to exclude from renaming + join_keys = _extract_join_key_column_names(source) if source else set() + + # Build a map from table name → actual ibis join depth by walking the + # join tree. The flat index in all_roots does NOT equal ibis join depth + # for nested joins (e.g. aircraft → aircraft_models inside a flights + # join tree), so we must compute it from the tree structure. 
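+    # Illustrative outcome (hypothetical tables): a duplicate ``city`` column
+    # on a right-side ``airports`` table at join depth 1 would be recorded as
+    # {"airports.city": "city_right"} in the returned rename map.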
+ join_depth_map: dict[str, int] = {} + if source is not None: + join_depth_map = _build_join_depth_map(source) + + # Process dimensions and determine which need renamed columns + rename_map = {} + + for idx, root in enumerate(all_roots): + if not root.name: + continue + + fields_dict = field_accessor(root) + if not fields_dict: + continue + + root_tbl = root.to_untagged() + # Use the actual join depth if available, otherwise fall back to table_idx + effective_depth = join_depth_map.get(root.name, idx) + + for field_name, field_value in fields_dict.items(): + # Extract column name using graph_utils (returns Maybe) + column_maybe = extract_column_from_dimension(field_value, root_tbl) + + # Use Maybe pattern from returns library + column_maybe.bind_optional( + lambda base_column: _check_and_add_rename( # noqa: B023 + rename_map=rename_map, + base_column=base_column, + prefixed_name=f"{root.name}.{field_name}", # noqa: B023 + table_idx=idx, # noqa: B023 + column_index=column_index, + join_keys=join_keys, + join_depth=effective_depth, # noqa: B023 + ) + ) + + return rename_map + + +def _check_and_add_rename( + rename_map: dict[str, str], + base_column: str, + prefixed_name: str, + table_idx: int, + column_index: dict[str, list[int]], + join_keys: set[str], + join_depth: int | None = None, +) -> None: + """ + Check if a column needs renaming and add to rename map if so. + + ``table_idx`` is the flat index in ``all_roots`` used to detect + whether an earlier table has the same column. ``join_depth`` is + the actual ibis join depth (from ``_build_join_depth_map``) used + to compute the ``_right`` / ``_right2`` / … suffix. + + Args: + rename_map: Map to update with renames + base_column: The base column name + prefixed_name: The prefixed dimension name (e.g., 'airports.city') + table_idx: Flat index in all_roots (for conflict detection) + column_index: Index of column occurrences + join_keys: Set of column names used as join keys (these don't get renamed) + join_depth: Actual ibis join depth for suffix computation (defaults to table_idx) + """ + # Skip columns that are join keys - they get merged, not renamed + if base_column in join_keys: + return + + depth = join_depth if join_depth is not None else table_idx + + if base_column in column_index: + tables_with_column = column_index[base_column] + # Check if any table before this one (in flat order) has the same column + earlier_tables = [t for t in tables_with_column if t < table_idx] + if earlier_tables: + suffix = "_right" if depth <= 1 else f"_right{depth}" + rename_map[prefixed_name] = f"{base_column}{suffix}" + + +def _wrap_dimension_for_renamed_column(dimension: Dimension, renamed_column: str) -> Dimension: + """ + Wrap a dimension to access a renamed column in a joined table. 
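+
+    Roughly: given a dimension that reads ``table["city"]`` and the renamed
+    column ``"city_right"``, the returned dimension reads
+    ``table["city_right"]`` while preserving all other dimension metadata.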
+ + Args: + dimension: The original dimension + renamed_column: The renamed column name (e.g., 'city_right') + + Returns: + A new Dimension that accesses the renamed column + """ + + # Create a new callable that accesses the renamed column + def renamed_accessor(table: ir.Table) -> ir.Value: + return table[renamed_column] + + # Return a new Dimension with the wrapped callable but same metadata + return Dimension( + expr=renamed_accessor, + description=dimension.description, + is_entity=dimension.is_entity, + is_time_dimension=dimension.is_time_dimension, + is_event_timestamp=dimension.is_event_timestamp, + smallest_time_grain=dimension.smallest_time_grain, + derived_dimensions=dimension.derived_dimensions, + ) + + +def _merge_fields_with_prefixing( + all_roots: Sequence[Any], + field_accessor: callable, + source: Relation | None = None, +) -> FrozenDict[str, Any]: + """ + Generic function to merge any type of fields (dimensions, measures) with prefixing. + + Args: + all_roots: List of SemanticTable root models + field_accessor: Function that takes a root and returns the fields dict (e.g. lambda r: r.dimensions) + source: Optional source relation to extract join keys from for proper column renaming + + Returns: + FrozenDict mapping field names (always prefixed with table name) to field values + """ + if not all_roots: + return FrozenDict() + + merged_fields = {} + + is_calc_measures = False + is_dimensions = False + if all_roots: + sample_fields = field_accessor(all_roots[0]) + if sample_fields: + from ..measure_scope import AllOf, BinOp, MeasureRef, MethodCall + + first_val = next(iter(sample_fields.values()), None) + is_calc_measures = isinstance( + first_val, + MeasureRef | AllOf | BinOp | MethodCall | int | float, + ) + is_dimensions = isinstance(first_val, Dimension) + + # For dimensions, build a column rename map to handle Ibis join conflicts + column_rename_map = {} + if is_dimensions: + column_rename_map = _build_column_rename_map(all_roots, field_accessor, source) + + for root in all_roots: + root_name = root.name + fields_dict = field_accessor(root) + + if is_calc_measures and root_name: + base_map = ( + {k: f"{root_name}.{k}" for k in root.get_measures()} + if hasattr(root, "get_measures") + else {} + ) + calc_map = ( + {k: f"{root_name}.{k}" for k in root.get_calculated_measures()} + if hasattr(root, "get_calculated_measures") + else {} + ) + prefix_map = {**base_map, **calc_map} + + for field_name, field_value in fields_dict.items(): + if root_name: + # Always use prefixed name with . separator + prefixed_name = f"{root_name}.{field_name}" + + # If it's a calculated measure, update internal MeasureRefs + if is_calc_measures: + field_value = _update_measure_refs_in_calc(field_value, prefix_map) + # If it's a dimension that needs column renaming, wrap the callable + elif is_dimensions and prefixed_name in column_rename_map: + field_value = _wrap_dimension_for_renamed_column( + field_value, column_rename_map[prefixed_name] + ) + + merged_fields[prefixed_name] = field_value + else: + # Fallback to original name if no root name + merged_fields[field_name] = field_value + + return FrozenDict(merged_fields) diff --git a/src/boring_semantic_layer/ops/_xorq_compat.py b/src/boring_semantic_layer/ops/_xorq_compat.py new file mode 100644 index 0000000..f4e9fbf --- /dev/null +++ b/src/boring_semantic_layer/ops/_xorq_compat.py @@ -0,0 +1,127 @@ +"""xorq/ibis compatibility shims. 
+ +These functions bridge plain ibis and xorq's vendored ibis: convert one +to the other (``_ensure_xorq_table``), patch SortKey shape differences +(``_patch_xorq_sortkey_compat``), and rebind ``DatabaseTable`` nodes so +expressions composed from separately-converted tables share a single +backend (``_rebind_to_backend`` / ``_rebind_to_canonical_backend``). +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + + +def _patch_xorq_sortkey_compat(): + """Register a map_ibis handler so ibis SortKey → xorq SortKey. + + ibis 11 uses ``SortKey.expr``, ibis 12 renamed it to ``SortKey.arg``, + while xorq's vendored ibis keeps ``SortKey.expr``. Handle both. + """ + from ibis.expr.operations.sortkeys import SortKey as IbisSortKey + + from .._xorq import SortKey as XorqSortKey, map_ibis + + if IbisSortKey in map_ibis.registry: + return # already patched + + @map_ibis.register(IbisSortKey) + def _map_sort_key(val, kwargs=None): + # ibis 12 uses .arg, ibis 11 uses .expr + sort_expr = getattr(val, "arg", None) or getattr(val, "expr") + return XorqSortKey( + expr=map_ibis(sort_expr, None), + ascending=val.ascending, + nulls_first=val.nulls_first, + ) + + +def _ensure_xorq_table(table): + """Convert plain ibis Table to xorq-vendored ibis if possible. + + This is the single boundary between user-supplied ibis tables and + BSL's internal xorq representation. ``SemanticModel`` calls it once + at construction so internal code paths can assume xorq tables when + the backend is supported, and a plain ibis fallback otherwise. + + Falls back to returning the plain ibis table when the backend is not + supported by xorq (e.g. Databricks). Idempotent: calling it on a + xorq-vendored table is a cheap no-op. + """ + _patch_xorq_sortkey_compat() + if "xorq.vendor.ibis" not in type(table).__module__: + try: + from .._xorq import from_ibis + + return from_ibis(table) + except Exception: + # Backend isn't supported by xorq's map_ibis registry (e.g. + # Databricks). Fall back so plain-ibis paths can still execute. + logger.debug( + "from_ibis failed for %s; using plain ibis table", + type(table).__module__, + exc_info=True, + ) + return table + return table + + +def _rebind_to_backend(expr, target_backend): + """Rebind every ``DatabaseTable`` op in *expr* to *target_backend*. + + Low-level primitive shared with ``serialization.reconstruct``. + No-op on plain ibis expressions or when xorq is unavailable for any + reason; callers must pass a xorq-vendored ``target_backend``. + """ + try: + from .._xorq import relations as xorq_rel + except ImportError: + return expr + + def _recreate(op, _kwargs, **overrides): + kwargs = dict(zip(op.__argnames__, op.__args__, strict=False)) + if _kwargs: + kwargs.update(_kwargs) + kwargs.update(overrides) + return op.__recreate__(kwargs) + + def replacer(op, _kwargs): + if isinstance(op, xorq_rel.DatabaseTable) and op.source is not target_backend: + return _recreate(op, _kwargs, source=target_backend) + if _kwargs: + return _recreate(op, _kwargs) + return op + + return expr.op().replace(replacer).to_expr() + + +def _rebind_to_canonical_backend(expr): + """Rebind divergent ``DatabaseTable`` backends in *expr* to share one. + + ``from_ibis()`` creates a distinct ``Backend`` per call, so expressions + built by composing separately-converted tables contain multiple + backends. Picking the first ``DatabaseTable``'s source as canonical + and rebinding the rest eliminates "Multiple backends found" errors + at execution time. 
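+
+    A minimal usage sketch (table and column names illustrative)::
+
+        t1 = from_ibis(raw_orders)       # backend instance #1
+        t2 = from_ibis(raw_users)        # backend instance #2
+        joined = t1.join(t2, "user_id")  # carries two divergent backends
+        joined = _rebind_to_canonical_backend(joined)  # shares t1's backend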
+ + No-op on plain ibis expressions (not xorq-vendored). + """ + try: + from .._xorq import relations as xorq_rel, walk_nodes + except ImportError: + return expr + + try: + db_tables = list(walk_nodes((xorq_rel.DatabaseTable,), expr)) + except Exception: + # walk_nodes can't traverse plain ibis trees; treat as no-op. + logger.debug("walk_nodes failed on plain ibis expr", exc_info=True) + return expr + + canonical = db_tables[0].source if db_tables else None + if canonical is None: + return expr + + return _rebind_to_backend(expr, canonical) From 83c2fbb30d0de505ee7004da282a00fe3284c561 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 4 May 2026 07:44:48 -0400 Subject: [PATCH 3/6] refactor(ops): extract format/repr helpers and SemanticIndexOp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two more chunks out of _core.py: - _format.py (159 lines): _collect_chain, _format_op_summary, _format_root, _semantic_repr — the helpers that build the multi-line pipeline repr used by every Semantic*Op.__repr__. All Op classes are imported lazily inside each helper to dodge the circular dependency. - _index_op.py (343 lines): SemanticIndexOp plus its 6 fragment builders (_get_field_type_str, _get_weight_expr, _build_string_index_fragment, _build_numeric_index_fragment, _resolve_selector, _get_fields_to_index). Self-contained except for lazy imports of _get_merged_fields/_to_untagged from _core. _core.py is now 3,775 lines. All extractions are pure moves with re-imports back so internal call sites are unchanged. 930 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/ops/_core.py | 431 +-------------------- src/boring_semantic_layer/ops/_format.py | 160 ++++++++ src/boring_semantic_layer/ops/_index_op.py | 321 +++++++++++++++ 3 files changed, 491 insertions(+), 421 deletions(-) create mode 100644 src/boring_semantic_layer/ops/_format.py create mode 100644 src/boring_semantic_layer/ops/_index_op.py diff --git a/src/boring_semantic_layer/ops/_core.py b/src/boring_semantic_layer/ops/_core.py index 39d64e6..af90eca 100644 --- a/src/boring_semantic_layer/ops/_core.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -126,138 +126,7 @@ def _unwrap(wrapped: Any) -> Any: return wrapped.unwrap if isinstance(wrapped, _CallableWrapper) else wrapped -def _collect_chain(op: Relation) -> list[Relation]: - """Walk op.source (or op.left for joins) back to root, return list from root to current.""" - chain = [op] - current = op - while True: - if hasattr(current, "source") and current.source is not None: - chain.append(current.source) - current = current.source - elif hasattr(current, "left") and current.left is not None: - chain.append(current.left) - current = current.left - else: - break - chain.reverse() - return chain - - -def _format_op_summary(op: Relation) -> str: - """Return a one-line summary string for a non-root semantic op.""" - # Import here to avoid circular imports at module level - cls = type(op).__name__ - - if isinstance(op, SemanticFilterOp): - predicate = object.__getattribute__(op, "predicate") - pred_name = "" - if hasattr(predicate, "__name__"): - pred_name = predicate.__name__ - elif hasattr(predicate, "unwrap"): - unwrapped = predicate.unwrap - if hasattr(unwrapped, "__name__"): - pred_name = unwrapped.__name__ - return f"Filter(\u03bb {pred_name})" - - if isinstance(op, SemanticMutateOp): - post = object.__getattribute__(op, "post") - cols = list(post.keys()) - return f"Mutate({', '.join(cols)})" - - if 
isinstance(op, SemanticGroupByOp): - keys = object.__getattribute__(op, "keys") - return f"GroupBy({', '.join(keys)})" - - if isinstance(op, SemanticAggregateOp): - aggs = object.__getattribute__(op, "aggs") - agg_names = list(aggs.keys()) - return f"Aggregate({', '.join(agg_names)})" - - if isinstance(op, SemanticOrderByOp): - keys = object.__getattribute__(op, "keys") - key_strs = [k if isinstance(k, str) else repr(k) for k in keys] - return f"OrderBy({', '.join(key_strs)})" - - if isinstance(op, SemanticLimitOp): - return f"Limit({op.n})" - - if isinstance(op, SemanticProjectOp): - fields = object.__getattribute__(op, "fields") - return f"Project({', '.join(fields)})" - - if isinstance(op, SemanticUnnestOp): - column = object.__getattribute__(op, "column") - return f"Unnest({column})" - - if isinstance(op, SemanticJoinOp): - how = object.__getattribute__(op, "how") - right = object.__getattribute__(op, "right") - right_name = "" - if isinstance(right, SemanticTableOp): - right_name = object.__getattribute__(right, "name") or "" - if not right_name: - # Try to find a root name from right side - roots = _find_all_root_models(right) - if roots: - right_name = object.__getattribute__(roots[0], "name") or "" - if right_name: - return f"Join({how}, right={right_name})" - return f"Join({how})" - - if isinstance(op, SemanticIndexOp): - parts = [] - selector = object.__getattribute__(op, "selector") - by = object.__getattribute__(op, "by") - sample = object.__getattribute__(op, "sample") - if selector is not None: - if isinstance(selector, tuple): - parts.append(", ".join(selector)) - else: - parts.append(str(selector)) - if by is not None: - parts.append(f"by={by}") - if sample is not None: - parts.append(f"sample={sample}") - return f"Index({', '.join(parts)})" - - # Fallback for unknown op types - return cls.replace("Semantic", "").replace("Op", "") - - -def _format_root(root_op: SemanticTableOp) -> str: - """Format a SemanticTableOp root using the fmt registry from format.py.""" - from boring_semantic_layer.format import fmt - - try: - return fmt(root_op) - except Exception: - # Fallback if format module isn't available - name = object.__getattribute__(root_op, "name") - return f"SemanticTable: {name}" if name else "SemanticTable" - - -def _semantic_repr(op: Relation) -> str: - chain = _collect_chain(op) - - # Find the root (first element should be a SemanticTableOp) - root = chain[0] - if isinstance(root, SemanticTableOp): - lines = [_format_root(root)] - else: - # Fallback: no SemanticTableOp root found - from ibis.expr.format import pretty - - try: - return pretty(op) - except Exception: - return object.__repr__(op) - - # Append pipeline steps - for step in chain[1:]: - if not isinstance(step, SemanticTableOp): - lines.append(f"-> {_format_op_summary(step)}") - - return "\n".join(lines) +from ._format import _collect_chain, _format_op_summary, _format_root, _semantic_repr # noqa: E402, F401 def _make_schema(fields_dict: dict[str, str]): @@ -3866,295 +3735,15 @@ def get_calculated_measures(self) -> Mapping[str, Any]: return self.source.get_calculated_measures() -def _get_field_type_str(field_type: Any) -> str: - return ( - "string" - if field_type.is_string() - else "number" - if field_type.is_numeric() - else "date" - if field_type.is_temporal() - else str(field_type) - ) - - -def _get_weight_expr( - base_tbl: Any, - by_measure: str | None, - all_roots: list, - is_string: bool, -) -> Any: - from .._xorq import api as xo - - if not by_measure: - return xo._.count() - - merged_measures = 
_get_merged_fields(all_roots, "measures") - return ( - merged_measures[by_measure](base_tbl) if by_measure in merged_measures else xo._.count() - ) - - -def _build_string_index_fragment( - base_tbl: Any, - field_expr: Any, - field_name: str, - field_path: str, - type_str: str, - weight_expr: Any, -) -> Any: - from .._xorq import api as xo - - return ( - base_tbl.group_by(field_expr.name("value")) - .aggregate(weight=weight_expr) - .select( - fieldName=xo.literal(field_name.split(".")[-1]), - fieldPath=xo.literal(field_path), - fieldType=xo.literal(type_str), - fieldValue=xo._["value"].cast("string"), - weight=xo._["weight"], - ) - ) - - -def _build_numeric_index_fragment( - base_tbl: Any, - field_expr: Any, - field_name: str, - field_path: str, - type_str: str, - weight_expr: Any, -) -> Any: - from .._xorq import api as xo - - return ( - base_tbl.select(field_expr.name("value")) - .filter(xo._["value"].notnull()) - .aggregate( - min_val=xo._["value"].min(), - max_val=xo._["value"].max(), - weight=weight_expr, - ) - .select( - fieldName=xo.literal(field_name.split(".")[-1]), - fieldPath=xo.literal(field_path), - fieldType=xo.literal(type_str), - fieldValue=( - xo._["min_val"].cast("string") + " to " + xo._["max_val"].cast("string") - ), - weight=xo._["weight"], - ) - ) - - -def _resolve_selector( - selector: str | list[str] | Callable | None, - base_tbl: ir.Table, -) -> tuple[str, ...]: - if selector is None: - return tuple(base_tbl.columns) - try: - selected = base_tbl.select(selector) - return tuple(selected.columns) - except Exception: - return [] - - -def _get_fields_to_index( - selector: str | list[str] | Callable | None, - merged_dimensions: dict, - base_tbl: ir.Table, -) -> tuple[str, ...]: - if selector is None: - selector = s.all() - - raw_fields = _resolve_selector(selector, base_tbl) - - if not raw_fields: - result = list(merged_dimensions.keys()) - result.extend(col for col in base_tbl.columns if col not in result) - else: - result = [col for col in raw_fields if col in merged_dimensions or col in base_tbl.columns] - - return result - - -class SemanticIndexOp(Relation): - source: Relation - selector: str | list[str] | tuple[str, ...] | Callable | None - by: str | None = None - sample: int | None = None - - def __init__( - self, - source: Relation, - selector: str | list[str] | tuple[str, ...] | Callable | None = None, - by: str | None = None, - sample: int | None = None, - ) -> None: - # Validate sample parameter - if sample is not None and sample <= 0: - raise ValueError(f"sample must be positive, got {sample}") - - # Validate 'by' measure exists if provided - if by is not None: - all_roots = _find_all_root_models(source) - if all_roots: - merged_measures = _get_merged_fields(all_roots, "measures") - if by not in merged_measures: - available = list(merged_measures.keys()) - raise KeyError( - f"Unknown measure '{by}' for weight calculation. 
" - f"Available measures: {', '.join(available) or 'none'}", - ) - - # Convert selector to tuple if it's a list (Ibis requires hashable types) - hashable_selector = tuple(selector) if isinstance(selector, list) else selector - - super().__init__( - source=Relation.__coerce__(source), - selector=hashable_selector, - by=by, - sample=sample, - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - from .._xorq import api as xo - - return FrozenOrderedDict( - { - "fieldName": xo.literal("").op(), - "fieldPath": xo.literal("").op(), - "fieldType": xo.literal("").op(), - "fieldValue": xo.literal("").op(), - "weight": xo.literal(0).op(), - }, - ) - - @property - def schema(self) -> Schema: - return Schema( - { - "fieldName": "string", - "fieldPath": "string", - "fieldType": "string", - "fieldValue": "string", - "weight": "int64", - }, - ) - - @property - def keys(self) -> tuple[str, ...]: - return ("fieldValue", "fieldName", "fieldPath", "fieldType") - - @property - def aggs(self) -> dict[str, Any]: - return {"weight": lambda t: t.weight} - - def to_untagged(self): - all_roots = _find_all_root_models(self.source) - base_tbl = ( - _to_untagged(self.source).limit(self.sample) - if self.sample - else _to_untagged(self.source) - ) - - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - fields_to_index = _get_fields_to_index( - self.selector, - merged_dimensions, - base_tbl, - ) - - if not fields_to_index: - from .._xorq import api as xo - - return xo.memtable( - { - "fieldName": [], - "fieldPath": [], - "fieldType": [], - "fieldValue": [], - "weight": [], - }, - ) - - def build_fragment(field_name: str) -> Any: - field_expr = ( - merged_dimensions[field_name](base_tbl) - if field_name in merged_dimensions - else base_tbl[field_name] - ) - field_type = field_expr.type() - type_str = _get_field_type_str(field_type) - weight_expr = _get_weight_expr( - base_tbl, - self.by, - all_roots, - field_type.is_string(), - ) - - return ( - _build_string_index_fragment( - base_tbl, - field_expr, - field_name, - field_name, - type_str, - weight_expr, - ) - if field_type.is_string() or not field_type.is_numeric() - else _build_numeric_index_fragment( - base_tbl, - field_expr, - field_name, - field_name, - type_str, - weight_expr, - ) - ) - - fragments = [build_fragment(f) for f in fields_to_index] - return reduce(lambda acc, frag: acc.union(frag), fragments[1:], fragments[0]) - - def filter(self, predicate: Callable) -> SemanticFilter: - from ..expr import SemanticFilter - - return SemanticFilter(source=self, predicate=predicate) - - def order_by(self, *keys: str | ir.Value | Callable) -> SemanticOrderBy: - from ..expr import SemanticOrderBy - - return SemanticOrderBy(source=self, keys=keys) - - def limit(self, n: int, offset: int = 0) -> SemanticLimit: - from ..expr import SemanticLimit - - return SemanticLimit(source=self, n=n, offset=offset) - - def execute(self): - return _rebind_to_canonical_backend(self.to_untagged()).execute() - - def as_expr(self): - """Return self as expression.""" - return self - - def compile(self, **kwargs): - return self.to_untagged().compile(**kwargs) - - def sql(self, **kwargs): - return ibis.to_sql(self.to_untagged(), **kwargs) - - def __getitem__(self, key): - return self.to_untagged()[key] - - def pipe(self, func, *args, **kwargs): - return func(self, *args, **kwargs) +from ._index_op import ( # noqa: E402 + SemanticIndexOp, + _build_numeric_index_fragment, + _build_string_index_fragment, + 
_get_field_type_str, + _get_fields_to_index, + _get_weight_expr, + _resolve_selector, +) from ._root_models import ( # noqa: E402 diff --git a/src/boring_semantic_layer/ops/_format.py b/src/boring_semantic_layer/ops/_format.py new file mode 100644 index 0000000..5c450b5 --- /dev/null +++ b/src/boring_semantic_layer/ops/_format.py @@ -0,0 +1,160 @@ +"""Pretty-printing of semantic-layer operation chains. + +Used by every ``Semantic*Op.__repr__``. The Op classes themselves are +imported lazily inside each helper to avoid a circular module +dependency with ``_core``. +""" + +from __future__ import annotations + +from ibis.expr.operations.relations import Relation + + +def _collect_chain(op: Relation) -> list[Relation]: + """Walk op.source (or op.left for joins) back to root, return list from root to current.""" + chain = [op] + current = op + while True: + if hasattr(current, "source") and current.source is not None: + chain.append(current.source) + current = current.source + elif hasattr(current, "left") and current.left is not None: + chain.append(current.left) + current = current.left + else: + break + chain.reverse() + return chain + + +def _format_op_summary(op: Relation) -> str: + """Return a one-line summary string for a non-root semantic op.""" + from ._core import ( + SemanticAggregateOp, + SemanticFilterOp, + SemanticGroupByOp, + SemanticIndexOp, + SemanticJoinOp, + SemanticLimitOp, + SemanticMutateOp, + SemanticOrderByOp, + SemanticProjectOp, + SemanticTableOp, + SemanticUnnestOp, + ) + from ._root_models import _find_all_root_models + + cls = type(op).__name__ + + if isinstance(op, SemanticFilterOp): + predicate = object.__getattribute__(op, "predicate") + pred_name = "" + if hasattr(predicate, "__name__"): + pred_name = predicate.__name__ + elif hasattr(predicate, "unwrap"): + unwrapped = predicate.unwrap + if hasattr(unwrapped, "__name__"): + pred_name = unwrapped.__name__ + return f"Filter(λ {pred_name})" + + if isinstance(op, SemanticMutateOp): + post = object.__getattribute__(op, "post") + cols = list(post.keys()) + return f"Mutate({', '.join(cols)})" + + if isinstance(op, SemanticGroupByOp): + keys = object.__getattribute__(op, "keys") + return f"GroupBy({', '.join(keys)})" + + if isinstance(op, SemanticAggregateOp): + aggs = object.__getattribute__(op, "aggs") + agg_names = list(aggs.keys()) + return f"Aggregate({', '.join(agg_names)})" + + if isinstance(op, SemanticOrderByOp): + keys = object.__getattribute__(op, "keys") + key_strs = [k if isinstance(k, str) else repr(k) for k in keys] + return f"OrderBy({', '.join(key_strs)})" + + if isinstance(op, SemanticLimitOp): + return f"Limit({op.n})" + + if isinstance(op, SemanticProjectOp): + fields = object.__getattribute__(op, "fields") + return f"Project({', '.join(fields)})" + + if isinstance(op, SemanticUnnestOp): + column = object.__getattribute__(op, "column") + return f"Unnest({column})" + + if isinstance(op, SemanticJoinOp): + how = object.__getattribute__(op, "how") + right = object.__getattribute__(op, "right") + right_name = "" + if isinstance(right, SemanticTableOp): + right_name = object.__getattribute__(right, "name") or "" + if not right_name: + # Try to find a root name from right side + roots = _find_all_root_models(right) + if roots: + right_name = object.__getattribute__(roots[0], "name") or "" + if right_name: + return f"Join({how}, right={right_name})" + return f"Join({how})" + + if isinstance(op, SemanticIndexOp): + parts = [] + selector = object.__getattribute__(op, "selector") + by = 
object.__getattribute__(op, "by") + sample = object.__getattribute__(op, "sample") + if selector is not None: + if isinstance(selector, tuple): + parts.append(", ".join(selector)) + else: + parts.append(str(selector)) + if by is not None: + parts.append(f"by={by}") + if sample is not None: + parts.append(f"sample={sample}") + return f"Index({', '.join(parts)})" + + # Fallback for unknown op types + return cls.replace("Semantic", "").replace("Op", "") + + +def _format_root(root_op) -> str: + """Format a SemanticTableOp root using the fmt registry from format.py.""" + from boring_semantic_layer.format import fmt + + try: + return fmt(root_op) + except Exception: + # Fallback if format module isn't available + name = object.__getattribute__(root_op, "name") + return f"SemanticTable: {name}" if name else "SemanticTable" + + +def _semantic_repr(op: Relation) -> str: + from ._core import SemanticTableOp + + chain = _collect_chain(op) + + # Find the root (first element should be a SemanticTableOp) + root = chain[0] + if isinstance(root, SemanticTableOp): + lines = [_format_root(root)] + else: + # Fallback: no SemanticTableOp root found + from ibis.expr.format import pretty + + try: + return pretty(op) + except Exception: + return object.__repr__(op) + + # Append pipeline steps + for step in chain[1:]: + if not isinstance(step, SemanticTableOp): + lines.append(f"-> {_format_op_summary(step)}") + + return "\n".join(lines) diff --git a/src/boring_semantic_layer/ops/_index_op.py b/src/boring_semantic_layer/ops/_index_op.py new file mode 100644 index 0000000..5013104 --- /dev/null +++ b/src/boring_semantic_layer/ops/_index_op.py @@ -0,0 +1,321 @@ +"""``SemanticIndexOp`` and its index-fragment helpers. + +``index()`` builds a tall ``(fieldName, fieldPath, fieldType, fieldValue, +weight)`` table for each indexed dimension. String fields produce one row +per distinct value; numeric fields collapse to a single ``min .. max`` +row. The result is a ``UNION ALL`` of these per-field fragments. 
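+
+Illustrative output (hypothetical data), indexing a string dimension
+``origin`` and a numeric dimension ``distance``::
+
+    fieldName  fieldPath  fieldType  fieldValue   weight
+    origin     origin     string     JFK          120
+    origin     origin     string     LAX          95
+    distance   distance   number     31 to 4983   336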
+
+"""

+from __future__ import annotations
+
+from collections.abc import Callable
+from functools import reduce
+from typing import TYPE_CHECKING, Any
+
+import ibis
+from ibis.expr import types as ir
+from ibis.expr.operations.relations import Relation
+
+from .._xorq import FrozenOrderedDict, Schema, selectors as s
+from ._format import _semantic_repr
+from ._root_models import _find_all_root_models
+from ._xorq_compat import _rebind_to_canonical_backend
+
+if TYPE_CHECKING:
+    from ..expr import SemanticFilter, SemanticLimit, SemanticOrderBy
+
+
+def _get_field_type_str(field_type: Any) -> str:
+    return (
+        "string"
+        if field_type.is_string()
+        else "number"
+        if field_type.is_numeric()
+        else "date"
+        if field_type.is_temporal()
+        else str(field_type)
+    )
+
+
+def _get_weight_expr(
+    base_tbl: Any,
+    by_measure: str | None,
+    all_roots: list,
+    is_string: bool,
+) -> Any:
+    from .._xorq import api as xo
+    from ._core import _get_merged_fields
+
+    if not by_measure:
+        return xo._.count()
+
+    merged_measures = _get_merged_fields(all_roots, "measures")
+    return (
+        merged_measures[by_measure](base_tbl) if by_measure in merged_measures else xo._.count()
+    )
+
+
+def _build_string_index_fragment(
+    base_tbl: Any,
+    field_expr: Any,
+    field_name: str,
+    field_path: str,
+    type_str: str,
+    weight_expr: Any,
+) -> Any:
+    from .._xorq import api as xo
+
+    return (
+        base_tbl.group_by(field_expr.name("value"))
+        .aggregate(weight=weight_expr)
+        .select(
+            fieldName=xo.literal(field_name.split(".")[-1]),
+            fieldPath=xo.literal(field_path),
+            fieldType=xo.literal(type_str),
+            fieldValue=xo._["value"].cast("string"),
+            weight=xo._["weight"],
+        )
+    )
+
+
+def _build_numeric_index_fragment(
+    base_tbl: Any,
+    field_expr: Any,
+    field_name: str,
+    field_path: str,
+    type_str: str,
+    weight_expr: Any,
+) -> Any:
+    from .._xorq import api as xo
+
+    return (
+        base_tbl.select(field_expr.name("value"))
+        .filter(xo._["value"].notnull())
+        .aggregate(
+            min_val=xo._["value"].min(),
+            max_val=xo._["value"].max(),
+            weight=weight_expr,
+        )
+        .select(
+            fieldName=xo.literal(field_name.split(".")[-1]),
+            fieldPath=xo.literal(field_path),
+            fieldType=xo.literal(type_str),
+            fieldValue=(
+                xo._["min_val"].cast("string") + " to " + xo._["max_val"].cast("string")
+            ),
+            weight=xo._["weight"],
+        )
+    )
+
+
+def _resolve_selector(
+    selector: str | list[str] | Callable | None,
+    base_tbl: ir.Table,
+) -> tuple[str, ...]:
+    if selector is None:
+        return tuple(base_tbl.columns)
+    try:
+        selected = base_tbl.select(selector)
+        return tuple(selected.columns)
+    except Exception:
+        # Selector didn't apply cleanly; return an empty tuple so the caller
+        # falls back to indexing dimensions plus raw columns.
+        return ()
+
+
+def _get_fields_to_index(
+    selector: str | list[str] | Callable | None,
+    merged_dimensions: dict,
+    base_tbl: ir.Table,
+) -> tuple[str, ...]:
+    if selector is None:
+        selector = s.all()
+
+    raw_fields = _resolve_selector(selector, base_tbl)
+
+    if not raw_fields:
+        result = list(merged_dimensions.keys())
+        result.extend(col for col in base_tbl.columns if col not in result)
+    else:
+        result = [col for col in raw_fields if col in merged_dimensions or col in base_tbl.columns]
+
+    return tuple(result)
+
+
+class SemanticIndexOp(Relation):
+    source: Relation
+    selector: str | list[str] | tuple[str, ...] | Callable | None
+    by: str | None = None
+    sample: int | None = None
+
+    def __init__(
+        self,
+        source: Relation,
+        selector: str | list[str] | tuple[str, ...] 
| Callable | None = None, + by: str | None = None, + sample: int | None = None, + ) -> None: + from ._core import _get_merged_fields + + # Validate sample parameter + if sample is not None and sample <= 0: + raise ValueError(f"sample must be positive, got {sample}") + + # Validate 'by' measure exists if provided + if by is not None: + all_roots = _find_all_root_models(source) + if all_roots: + merged_measures = _get_merged_fields(all_roots, "measures") + if by not in merged_measures: + available = list(merged_measures.keys()) + raise KeyError( + f"Unknown measure '{by}' for weight calculation. " + f"Available measures: {', '.join(available) or 'none'}", + ) + + # Convert selector to tuple if it's a list (Ibis requires hashable types) + hashable_selector = tuple(selector) if isinstance(selector, list) else selector + + super().__init__( + source=Relation.__coerce__(source), + selector=hashable_selector, + by=by, + sample=sample, + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + from .._xorq import api as xo + + return FrozenOrderedDict( + { + "fieldName": xo.literal("").op(), + "fieldPath": xo.literal("").op(), + "fieldType": xo.literal("").op(), + "fieldValue": xo.literal("").op(), + "weight": xo.literal(0).op(), + }, + ) + + @property + def schema(self) -> Schema: + return Schema( + { + "fieldName": "string", + "fieldPath": "string", + "fieldType": "string", + "fieldValue": "string", + "weight": "int64", + }, + ) + + @property + def keys(self) -> tuple[str, ...]: + return ("fieldValue", "fieldName", "fieldPath", "fieldType") + + @property + def aggs(self) -> dict[str, Any]: + return {"weight": lambda t: t.weight} + + def to_untagged(self): + from ._core import _get_merged_fields, _to_untagged + + all_roots = _find_all_root_models(self.source) + base_tbl = ( + _to_untagged(self.source).limit(self.sample) + if self.sample + else _to_untagged(self.source) + ) + + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + fields_to_index = _get_fields_to_index( + self.selector, + merged_dimensions, + base_tbl, + ) + + if not fields_to_index: + from .._xorq import api as xo + + return xo.memtable( + { + "fieldName": [], + "fieldPath": [], + "fieldType": [], + "fieldValue": [], + "weight": [], + }, + ) + + def build_fragment(field_name: str) -> Any: + field_expr = ( + merged_dimensions[field_name](base_tbl) + if field_name in merged_dimensions + else base_tbl[field_name] + ) + field_type = field_expr.type() + type_str = _get_field_type_str(field_type) + weight_expr = _get_weight_expr( + base_tbl, + self.by, + all_roots, + field_type.is_string(), + ) + + return ( + _build_string_index_fragment( + base_tbl, + field_expr, + field_name, + field_name, + type_str, + weight_expr, + ) + if field_type.is_string() or not field_type.is_numeric() + else _build_numeric_index_fragment( + base_tbl, + field_expr, + field_name, + field_name, + type_str, + weight_expr, + ) + ) + + fragments = [build_fragment(f) for f in fields_to_index] + return reduce(lambda acc, frag: acc.union(frag), fragments[1:], fragments[0]) + + def filter(self, predicate: Callable) -> "SemanticFilter": + from ..expr import SemanticFilter + + return SemanticFilter(source=self, predicate=predicate) + + def order_by(self, *keys: str | ir.Value | Callable) -> "SemanticOrderBy": + from ..expr import SemanticOrderBy + + return SemanticOrderBy(source=self, keys=keys) + + def limit(self, n: int, offset: int = 0) -> "SemanticLimit": + from ..expr import 
SemanticLimit + + return SemanticLimit(source=self, n=n, offset=offset) + + def execute(self): + return _rebind_to_canonical_backend(self.to_untagged()).execute() + + def as_expr(self): + """Return self as expression.""" + return self + + def compile(self, **kwargs): + return self.to_untagged().compile(**kwargs) + + def sql(self, **kwargs): + return ibis.to_sql(self.to_untagged(), **kwargs) + + def __getitem__(self, key): + return self.to_untagged()[key] + + def pipe(self, func, *args, **kwargs): + return func(self, *args, **kwargs) From 51706395611584b5f10c94f8fce659686692c33e Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 4 May 2026 07:47:48 -0400 Subject: [PATCH 4/6] refactor(ops): extract OrderBy/Limit and Mutate/Unnest ops Two more pairs of small "transparent" Op classes pulled out: - _order_limit.py (122 lines): SemanticOrderByOp + SemanticLimitOp. Both delegate every metadata accessor to self.source and only override to_untagged() and __repr__. - _mutate_unnest.py (167 lines): SemanticMutateOp + SemanticUnnestOp. Same pass-through pattern; struct unpacking lives in Unnest's to_untagged. Helpers used inside to_untagged (_to_untagged, _resolve_expr, _unwrap) are imported lazily to keep the modules independent of _core import order. _core.py now 3,551 lines. 930 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/ops/_core.py | 228 +----------------- .../ops/_mutate_unnest.py | 161 +++++++++++++ src/boring_semantic_layer/ops/_order_limit.py | 118 +++++++++ 3 files changed, 281 insertions(+), 226 deletions(-) create mode 100644 src/boring_semantic_layer/ops/_mutate_unnest.py create mode 100644 src/boring_semantic_layer/ops/_order_limit.py diff --git a/src/boring_semantic_layer/ops/_core.py b/src/boring_semantic_layer/ops/_core.py index af90eca..a54083c 100644 --- a/src/boring_semantic_layer/ops/_core.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -2644,141 +2644,7 @@ def _apply_calc_specs(result, plan, tbl): return out.mutate(**calc_cols) -class SemanticMutateOp(Relation): - source: Relation - post: dict[ - str, - Callable, - ] # Transformed to FrozenDict[str, _CallableWrapper] in __init__ - nested_columns: tuple[ - str, - ..., - ] = () # Inherited from source if it has nested columns - - def __init__( - self, - source: Relation, - post: dict[str, Callable] | None, - nested_columns: tuple[str, ...] 
= (), - ) -> None: - frozen_post = FrozenDict( - {name: _ensure_wrapped(fn) for name, fn in (post or {}).items()}, - ) - source_nested = nested_columns if nested_columns else getattr(source, "nested_columns", ()) - - super().__init__( - source=Relation.__coerce__(source), - post=frozen_post, - nested_columns=source_nested, - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - - @property - def schema(self) -> Schema: - return self.source.schema - - def to_untagged(self): - agg_tbl = _to_untagged(self.source) - - # Process mutations incrementally so each can reference previous ones - # This allows: .mutate(rank=..., is_other=lambda t: t["rank"] > 5) - current_tbl = agg_tbl - for name, fn_wrapped in self.post.items(): - proxy = MeasureScope(_tbl=current_tbl, _known=[], _post_agg=True) - resolved = _resolve_expr(_unwrap(fn_wrapped), proxy) - - new_col = resolved.name(name) - current_tbl = current_tbl.mutate([new_col]) - - return current_tbl - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() - - -class SemanticUnnestOp(Relation): - """Unnest an array column, expanding rows (like Malloy's nested data pattern).""" - - source: Relation - column: str - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def schema(self) -> Schema: - # After unnesting, the schema changes - the array column is replaced by its element schema - # For now, delegate to source schema (ideally we'd update it) - return self.source.schema - - @property - def values(self) -> FrozenDict: - return FrozenDict({}) - - def to_untagged(self): - """Convert to Ibis expression with functional struct unpacking. - - Uses pure helper functions to extract struct fields when unnesting - produces struct columns that need to be expanded. 
- """ - - def build_struct_fields(col_expr, col_type): - """Pure function: build dict of struct field selections.""" - return {name: col_expr[name] for name in col_type.names} - - def unpack_struct_if_needed(unnested_tbl, column_name): - """Conditionally unpack struct fields into top-level columns.""" - if column_name not in unnested_tbl.columns: - return unnested_tbl - - col_expr = unnested_tbl[column_name] - col_type = col_expr.type() - - # Only Struct types have fields to unpack - if isinstance(col_type, dt.Struct) and col_type.fields: - struct_fields = build_struct_fields(col_expr, col_type) - return unnested_tbl.select(unnested_tbl, **struct_fields) - - return unnested_tbl - - tbl = _to_untagged(self.source) - - if self.column not in tbl.columns: - raise ValueError(f"Column '{self.column}' not found in table") - - try: - unnested = tbl.unnest(self.column) - except Exception as e: - raise ValueError(f"Failed to unnest column '{self.column}': {e}") from e - - return unpack_struct_if_needed(unnested, self.column) - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() +from ._mutate_unnest import SemanticMutateOp, SemanticUnnestOp # noqa: E402, F401 class SemanticJoinOp(Relation): @@ -3642,97 +3508,7 @@ def as_table(self) -> SemanticTable: ) -class SemanticOrderByOp(Relation): - source: Relation - keys: tuple[ - str | ir.Value | Callable, - ..., - ] # Transformed to tuple[str | _CallableWrapper, ...] 
in __init__ - - def __init__(self, source: Relation, keys: Iterable[str | ir.Value | Callable]) -> None: - def wrap_key(k): - return k if isinstance(k, str | _CallableWrapper) else _ensure_wrapped(k) - - super().__init__( - source=Relation.__coerce__(source), - keys=tuple(wrap_key(k) for k in keys), - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - - @property - def schema(self) -> Schema: - return self.source.schema - - def to_untagged(self): - tbl = _to_untagged(self.source) - - def resolve_order_key(key): - if isinstance(key, str): - return tbl[key] if key in tbl.columns else getattr(tbl, key, key) - elif isinstance(key, _CallableWrapper): - unwrapped = _unwrap(key) - return _resolve_expr(unwrapped, tbl) - return key - - return tbl.order_by([resolve_order_key(key) for key in self.keys]) - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() - - -class SemanticLimitOp(Relation): - source: Relation - n: int - offset: int - - def __init__(self, source: Relation, n: int, offset: int = 0) -> None: - if n <= 0: - raise ValueError(f"limit must be positive, got {n}") - if offset < 0: - raise ValueError(f"offset must be non-negative, got {offset}") - super().__init__(source=Relation.__coerce__(source), n=n, offset=offset) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - - @property - def schema(self) -> Schema: - return self.source.schema - - def to_untagged(self): - tbl = _to_untagged(self.source) - return tbl.limit(self.n) if self.offset == 0 else tbl.limit(self.n, offset=self.offset) - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() +from ._order_limit import SemanticLimitOp, SemanticOrderByOp # noqa: E402, F401 from ._index_op import ( # noqa: E402 diff --git a/src/boring_semantic_layer/ops/_mutate_unnest.py b/src/boring_semantic_layer/ops/_mutate_unnest.py new file mode 100644 index 0000000..7f87576 --- /dev/null +++ b/src/boring_semantic_layer/ops/_mutate_unnest.py @@ -0,0 +1,161 @@ +"""``SemanticMutateOp`` and ``SemanticUnnestOp``. + +Mutate adds derived columns to an aggregated/projected table; Unnest +explodes an array column into one row per element. Both pass through +dimension/measure metadata to their source. 
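+
+Illustrative use (expression-layer method names assumed)::
+
+    # mutations apply incrementally, so later ones can see earlier ones
+    agg.mutate(rank=..., is_other=lambda t: t["rank"] > 5)
+    # explode an array column into one row per element
+    orders.unnest("items")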
+""" + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from typing import Any + +from ibis.expr import datatypes as dt +from ibis.expr.operations.relations import Relation +from ibis.expr.schema import Schema + +from .._xorq import FrozenDict, FrozenOrderedDict +from ..measure_scope import MeasureScope +from ._callable import _ensure_wrapped +from ._format import _semantic_repr +from ._values import Dimension, Measure + + +class SemanticMutateOp(Relation): + source: Relation + post: dict[ + str, + Callable, + ] # Transformed to FrozenDict[str, _CallableWrapper] in __init__ + nested_columns: tuple[ + str, + ..., + ] = () # Inherited from source if it has nested columns + + def __init__( + self, + source: Relation, + post: dict[str, Callable] | None, + nested_columns: tuple[str, ...] = (), + ) -> None: + frozen_post = FrozenDict( + {name: _ensure_wrapped(fn) for name, fn in (post or {}).items()}, + ) + source_nested = nested_columns if nested_columns else getattr(source, "nested_columns", ()) + + super().__init__( + source=Relation.__coerce__(source), + post=frozen_post, + nested_columns=source_nested, + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _resolve_expr, _to_untagged, _unwrap + + agg_tbl = _to_untagged(self.source) + + # Process mutations incrementally so each can reference previous ones + # This allows: .mutate(rank=..., is_other=lambda t: t["rank"] > 5) + current_tbl = agg_tbl + for name, fn_wrapped in self.post.items(): + proxy = MeasureScope(_tbl=current_tbl, _known=[], _post_agg=True) + resolved = _resolve_expr(_unwrap(fn_wrapped), proxy) + + new_col = resolved.name(name) + current_tbl = current_tbl.mutate([new_col]) + + return current_tbl + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() + + +class SemanticUnnestOp(Relation): + """Unnest an array column, expanding rows (like Malloy's nested data pattern).""" + + source: Relation + column: str + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def schema(self) -> Schema: + # After unnesting, the schema changes - the array column is replaced by its element schema + # For now, delegate to source schema (ideally we'd update it) + return self.source.schema + + @property + def values(self) -> FrozenDict: + return FrozenDict({}) + + def to_untagged(self): + """Convert to Ibis expression with functional struct unpacking. + + Uses pure helper functions to extract struct fields when unnesting + produces struct columns that need to be expanded. 
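+
+        For instance (hypothetical column), if ``items`` has type
+        ``array<struct<sku: string, qty: int64>>``, ``tbl.unnest("items")``
+        yields one row per array element, and ``sku``/``qty`` are then
+        selected out as top-level columns.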
+ """ + from ._core import _to_untagged + + def build_struct_fields(col_expr, col_type): + """Pure function: build dict of struct field selections.""" + return {name: col_expr[name] for name in col_type.names} + + def unpack_struct_if_needed(unnested_tbl, column_name): + """Conditionally unpack struct fields into top-level columns.""" + if column_name not in unnested_tbl.columns: + return unnested_tbl + + col_expr = unnested_tbl[column_name] + col_type = col_expr.type() + + # Only Struct types have fields to unpack + if isinstance(col_type, dt.Struct) and col_type.fields: + struct_fields = build_struct_fields(col_expr, col_type) + return unnested_tbl.select(unnested_tbl, **struct_fields) + + return unnested_tbl + + tbl = _to_untagged(self.source) + + if self.column not in tbl.columns: + raise ValueError(f"Column '{self.column}' not found in table") + + try: + unnested = tbl.unnest(self.column) + except Exception as e: + raise ValueError(f"Failed to unnest column '{self.column}': {e}") from e + + return unpack_struct_if_needed(unnested, self.column) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() diff --git a/src/boring_semantic_layer/ops/_order_limit.py b/src/boring_semantic_layer/ops/_order_limit.py new file mode 100644 index 0000000..8bda048 --- /dev/null +++ b/src/boring_semantic_layer/ops/_order_limit.py @@ -0,0 +1,118 @@ +"""``SemanticOrderByOp`` and ``SemanticLimitOp`` — terminal pass-through ops. + +Both are thin wrappers that delegate dimension/measure metadata to their +source and only modify ``to_untagged()`` output. They share the same +shape: pass-through ``schema``/``values``, override ``__repr__`` and +``to_untagged``, and forward ``get_*`` accessors to ``self.source``. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Mapping +from typing import Any + +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation +from ibis.expr.schema import Schema + +from .._xorq import FrozenOrderedDict +from ._callable import _CallableWrapper, _ensure_wrapped +from ._format import _semantic_repr +from ._values import Dimension, Measure + + +class SemanticOrderByOp(Relation): + source: Relation + keys: tuple[ + str | ir.Value | Callable, + ..., + ] # Transformed to tuple[str | _CallableWrapper, ...] 
in __init__ + + def __init__(self, source: Relation, keys: Iterable[str | ir.Value | Callable]) -> None: + def wrap_key(k): + return k if isinstance(k, str | _CallableWrapper) else _ensure_wrapped(k) + + super().__init__( + source=Relation.__coerce__(source), + keys=tuple(wrap_key(k) for k in keys), + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _resolve_expr, _to_untagged, _unwrap + + tbl = _to_untagged(self.source) + + def resolve_order_key(key): + if isinstance(key, str): + return tbl[key] if key in tbl.columns else getattr(tbl, key, key) + elif isinstance(key, _CallableWrapper): + unwrapped = _unwrap(key) + return _resolve_expr(unwrapped, tbl) + return key + + return tbl.order_by([resolve_order_key(key) for key in self.keys]) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() + + +class SemanticLimitOp(Relation): + source: Relation + n: int + offset: int + + def __init__(self, source: Relation, n: int, offset: int = 0) -> None: + if n <= 0: + raise ValueError(f"limit must be positive, got {n}") + if offset < 0: + raise ValueError(f"offset must be non-negative, got {offset}") + super().__init__(source=Relation.__coerce__(source), n=n, offset=offset) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _to_untagged + + tbl = _to_untagged(self.source) + return tbl.limit(self.n) if self.offset == 0 else tbl.limit(self.n, offset=self.offset) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() From a54714a388c42d6a2b4ae0b0d7cca18ee9847215 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 4 May 2026 07:51:01 -0400 Subject: [PATCH 5/6] refactor(ops): extract measure-classification & column-error helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New _measure_helpers.py (~280 lines): - _extract_measure_metadata: read user-supplied measure forms (dict/Measure/raw) into a uniform tuple - _is_calculated_measure / _matches_aggregation_pattern / _find_matching_measure: detect calc measures vs base measures and match aggregation patterns against known measure definitions - _make_base_measure: wrap a callable/Deferred/AggregationExpr into a Measure - _classify_measure: top-level dispatch — used externally by expr.py - _build_json_definition: dim/measure JSON-export helper - _format_column_error: friendly diagnostic for 
missing-column errors _values.py's lazy import of _format_column_error switches to the new home so the value-objects → core back-ref is gone. _core.py is now 3,283 lines (down from 3,551). 930 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/ops/_core.py | 290 +--------------- .../ops/_measure_helpers.py | 314 ++++++++++++++++++ src/boring_semantic_layer/ops/_values.py | 2 +- 3 files changed, 326 insertions(+), 280 deletions(-) create mode 100644 src/boring_semantic_layer/ops/_measure_helpers.py diff --git a/src/boring_semantic_layer/ops/_core.py b/src/boring_semantic_layer/ops/_core.py index a54083c..654922c 100644 --- a/src/boring_semantic_layer/ops/_core.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -277,285 +277,17 @@ def _classify_dependencies( from ._callable import _CallableWrapper, _ensure_wrapped, _infer_unnest # noqa: E402 -def _extract_measure_metadata( - fn_or_expr: Any, -) -> tuple[Any, str | None, tuple, Mapping[str, Any]]: - """Extract metadata from various measure representations.""" - if isinstance(fn_or_expr, dict): - return ( - fn_or_expr["expr"], - fn_or_expr.get("description"), - tuple(fn_or_expr.get("requires_unnest", [])), - dict(fn_or_expr.get("metadata") or {}), - ) - elif isinstance(fn_or_expr, Measure): - return ( - fn_or_expr.expr, - fn_or_expr.description, - fn_or_expr.requires_unnest, - dict(fn_or_expr.metadata), - ) - else: - return (fn_or_expr, None, (), {}) - - -_AGG_METHODS = frozenset({"sum", "mean", "avg", "count", "min", "max"}) - - -def _is_calculated_measure(val: Any) -> bool: - # A MethodCall with an aggregation method on a MeasureRef is a base measure: - # the column name matched a known measure name in MeasureScope, but the user - # is really defining a column aggregation (e.g. lambda t: t.flight_count.sum()). - if ( - isinstance(val, MethodCall) - and val.method in _AGG_METHODS - and isinstance(val.receiver, MeasureRef) - ): - return False - return isinstance(val, MeasureRef | AllOf | BinOp | MethodCall | int | float) - - -def _matches_aggregation_pattern(measure_expr, agg_expr, tbl): - if not isinstance(agg_expr, AggregationExpr): - return Success(False) - - @curry - def evaluate_in_scope(tbl, expr): - """Evaluate measure expression in a ColumnScope.""" - scope = ColumnScope(_tbl=tbl) - return ( - expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr - ) - - @curry - def has_matching_operation(agg_expr, result): - """Check if the operation matches the expected aggregation. - - All our supported aggregations (Sum, Mean, Count, Min, Max) are ibis operations. - """ - op_name = type(result.op()).__name__.lower() - expected_op = "avg" if agg_expr.operation.lower() == "mean" else agg_expr.operation.lower() - - return expected_op in op_name - - @curry - def has_matching_column(agg_expr, result): - """Check if result's operation references the expected column. 
- - All supported aggregation operations (Sum, Mean, Count, Min, Max) have: - - args[0]: Field operation with .name attribute - - args[1]: Optional where clause (typically None) - """ - op = result.op() - - if not isinstance(op.args[0], Field): - return False - - return op.args[0].name == agg_expr.column - - def matches_pattern(result): - """Check if result matches both operation and column.""" - return has_matching_operation(agg_expr, result) and has_matching_column(agg_expr, result) - - return safe(lambda: evaluate_in_scope(tbl, measure_expr))().map(matches_pattern) - - -def _find_matching_measure(agg_expr, known_measures: dict, tbl): - """Find a measure that matches the aggregation expression pattern. - - Returns Maybe[str] using functional patterns. - """ - if not isinstance(agg_expr, AggregationExpr): - return Nothing - - @curry - def matches_pattern(agg_expr, tbl, measure_obj): - """Check if measure matches the aggregation pattern. - - All measure_obj values are Measure instances with an expr attribute. - """ - result = _matches_aggregation_pattern(measure_obj.expr, agg_expr, tbl) - return result.value_or(False) - - for measure_name, measure_obj in known_measures.items(): - if matches_pattern(agg_expr, tbl, measure_obj): - return Some(measure_name) - - return Nothing - - -def _make_base_measure( - expr: Any, - description: str | None, - requires_unnest: tuple, - metadata: Mapping[str, Any] | None = None, -) -> Measure: - """Create a base measure with proper callable wrapping using functional patterns.""" - - @curry - def apply_aggregation(operation: str, column): - """Apply aggregation operation to a column using functional dispatch.""" - operations = { - "sum": lambda c: c.sum(), - "mean": lambda c: c.mean(), - "avg": lambda c: c.mean(), - "count": lambda c: c.count(), - "min": lambda c: c.min(), - "max": lambda c: c.max(), - } - - return ( - Maybe.from_optional(operations.get(operation)) - .map(lambda fn: fn(column)) - .value_or( - (_ for _ in ()).throw(ValueError(f"Unknown aggregation operation: {operation}")) - ) - ) - - @curry - def evaluate_expr(expr, scope): - """Evaluate expression in given scope.""" - return ( - expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr - ) - - def convert_aggregation_expr(t, agg_expr: AggregationExpr): - """Convert AggregationExpr to ibis expression.""" - if agg_expr.operation == "count": - result = t.count() - else: - result = apply_aggregation(agg_expr.operation, t[agg_expr.column]) - - for method_name, args, kwargs_tuple in agg_expr.post_ops: - result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) - - return result - - raw_expr = expr._fn if isinstance(expr, _CallableWrapper) else expr - - if isinstance(expr, AggregationExpr): - - def wrapped_expr(t): - """Convert AggregationExpr to ibis expression.""" - return convert_aggregation_expr(t, expr) - - return Measure( - expr=wrapped_expr, - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) - - if callable(expr): - - def wrapped_expr(t): - """Wrapped expression that handles AggregationExpr conversion.""" - scope = ColumnScope(_tbl=t) - result = evaluate_expr(expr, scope) - - if isinstance(result, AggregationExpr): - return convert_aggregation_expr(t, result) - return result - - return Measure( - expr=wrapped_expr, - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) - else: - return Measure( - expr=lambda t, 
fn=expr: evaluate_expr(fn, ColumnScope(_tbl=t)), - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) - - -def _classify_measure( - fn_or_expr: Any, scope: Any, measure_name: str | None = None -) -> tuple[str, Any]: - """Classify measure as 'calc' or 'base' with appropriate handling.""" - from ..measure_scope import validate_calc_ast - - expr, description, requires_unnest, metadata = _extract_measure_metadata(fn_or_expr) - - resolved = safe(lambda: _resolve_expr(expr, scope))().map( - lambda val: ("calc", val) if _is_calculated_measure(val) else None - ) - - if isinstance(resolved, Success) and resolved.unwrap() is not None: - kind, value = resolved.unwrap() - validate_calc_ast(value, measure_name) - return (kind, value) - - if not requires_unnest and callable(expr): - # All scopes (MeasureScope, ColumnScope) have tbl attribute - table = scope.tbl - inferred_unnest = _infer_unnest(expr, table) - requires_unnest = requires_unnest or inferred_unnest - - return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) - - -def _build_json_definition( - dims_dict: dict, - meas_dict: dict, - name: str | None = None, - description: str | None = None, -) -> dict: - result = { - "dimensions": {n: spec.to_json() for n, spec in dims_dict.items()}, - "measures": {n: spec.to_json() for n, spec in meas_dict.items()}, - "entity_dimensions": {n: spec.to_json() for n, spec in dims_dict.items() if spec.is_entity}, - "event_timestamp": { - n: spec.to_json() for n, spec in dims_dict.items() if spec.is_event_timestamp - }, - "time_dimensions": { - n: spec.to_json() for n, spec in dims_dict.items() if spec.is_time_dimension - }, - "name": name, - } - if description is not None: - result["description"] = description - return result - - -def _format_column_error(e: AttributeError, table: ir.Table) -> str: - """Format a helpful error message for missing column errors.""" - # Extract the column name from the error - match = re.search(r"has no attribute ['\"]([^'\"]+)['\"]", str(e)) - missing_col = match.group(1) if match else "unknown" - - # Get available columns - available_cols = list(table.columns) if hasattr(table, "columns") else [] - - # Build error message - parts = [f"Dimension expression references non-existent column '{missing_col}'."] - - if len(available_cols) > 20: - parts.append(f"Table has {len(available_cols)} columns. First 15: {available_cols[:15]}") - elif available_cols: - parts.append(f"Available columns: {available_cols}") - else: - parts.append(f"No columns available in {type(table).__name__} object") - - # Suggest similar column names - suggestions = get_close_matches(missing_col, available_cols, n=3, cutoff=0.6) - if suggestions: - parts[-1] += f". Did you mean: {suggestions}?" - - # Add helpful tip - example = suggestions[0] if suggestions else "column_name" - parts.append( - f"\n\nTip: Check that your dimension expression uses the correct column name. 
" - f"For example: lambda t: t.{example}" - ) - - return " ".join(parts) +from ._measure_helpers import ( # noqa: E402 + _AGG_METHODS, + _build_json_definition, + _classify_measure, + _extract_measure_metadata, + _find_matching_measure, + _format_column_error, + _is_calculated_measure, + _make_base_measure, + _matches_aggregation_pattern, +) from ._values import ( # noqa: E402 diff --git a/src/boring_semantic_layer/ops/_measure_helpers.py b/src/boring_semantic_layer/ops/_measure_helpers.py new file mode 100644 index 0000000..f531ba1 --- /dev/null +++ b/src/boring_semantic_layer/ops/_measure_helpers.py @@ -0,0 +1,314 @@ +"""Measure-classification, JSON-definition, and column-error helpers. + +These analyze user-supplied measure expressions to decide whether they're +"base" (apply an aggregation to a column) or "calc" (reference other +measures), wrap them so the same pipeline can execute both, and produce +the human-readable diagnostic when a dimension lambda references a +column that doesn't exist on the table. +""" + +from __future__ import annotations + +import re +from collections.abc import Mapping +from difflib import get_close_matches +from typing import Any + +from ibis.expr import types as ir +from ibis.expr.operations.relations import Field +from returns.maybe import Maybe, Nothing, Some +from returns.result import Success, safe +from toolz import curry + +from ..measure_scope import ( + AggregationExpr, + AllOf, + BinOp, + ColumnScope, + MeasureRef, + MethodCall, +) +from ._callable import _CallableWrapper, _infer_unnest +from ._values import Measure, _is_deferred + + +def _extract_measure_metadata( + fn_or_expr: Any, +) -> tuple[Any, str | None, tuple, Mapping[str, Any]]: + """Extract metadata from various measure representations.""" + if isinstance(fn_or_expr, dict): + return ( + fn_or_expr["expr"], + fn_or_expr.get("description"), + tuple(fn_or_expr.get("requires_unnest", [])), + dict(fn_or_expr.get("metadata") or {}), + ) + elif isinstance(fn_or_expr, Measure): + return ( + fn_or_expr.expr, + fn_or_expr.description, + fn_or_expr.requires_unnest, + dict(fn_or_expr.metadata), + ) + else: + return (fn_or_expr, None, (), {}) + + +_AGG_METHODS = frozenset({"sum", "mean", "avg", "count", "min", "max"}) + + +def _is_calculated_measure(val: Any) -> bool: + # A MethodCall with an aggregation method on a MeasureRef is a base measure: + # the column name matched a known measure name in MeasureScope, but the user + # is really defining a column aggregation (e.g. lambda t: t.flight_count.sum()). + if ( + isinstance(val, MethodCall) + and val.method in _AGG_METHODS + and isinstance(val.receiver, MeasureRef) + ): + return False + return isinstance(val, MeasureRef | AllOf | BinOp | MethodCall | int | float) + + +def _matches_aggregation_pattern(measure_expr, agg_expr, tbl): + if not isinstance(agg_expr, AggregationExpr): + return Success(False) + + @curry + def evaluate_in_scope(tbl, expr): + """Evaluate measure expression in a ColumnScope.""" + scope = ColumnScope(_tbl=tbl) + return ( + expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr + ) + + @curry + def has_matching_operation(agg_expr, result): + """Check if the operation matches the expected aggregation. + + All our supported aggregations (Sum, Mean, Count, Min, Max) are ibis operations. 
+ """ + op_name = type(result.op()).__name__.lower() + expected_op = "avg" if agg_expr.operation.lower() == "mean" else agg_expr.operation.lower() + + return expected_op in op_name + + @curry + def has_matching_column(agg_expr, result): + """Check if result's operation references the expected column. + + All supported aggregation operations (Sum, Mean, Count, Min, Max) have: + - args[0]: Field operation with .name attribute + - args[1]: Optional where clause (typically None) + """ + op = result.op() + + if not isinstance(op.args[0], Field): + return False + + return op.args[0].name == agg_expr.column + + def matches_pattern(result): + """Check if result matches both operation and column.""" + return has_matching_operation(agg_expr, result) and has_matching_column(agg_expr, result) + + return safe(lambda: evaluate_in_scope(tbl, measure_expr))().map(matches_pattern) + + +def _find_matching_measure(agg_expr, known_measures: dict, tbl): + """Find a measure that matches the aggregation expression pattern. + + Returns Maybe[str] using functional patterns. + """ + if not isinstance(agg_expr, AggregationExpr): + return Nothing + + @curry + def matches_pattern(agg_expr, tbl, measure_obj): + """Check if measure matches the aggregation pattern. + + All measure_obj values are Measure instances with an expr attribute. + """ + result = _matches_aggregation_pattern(measure_obj.expr, agg_expr, tbl) + return result.value_or(False) + + for measure_name, measure_obj in known_measures.items(): + if matches_pattern(agg_expr, tbl, measure_obj): + return Some(measure_name) + + return Nothing + + +def _make_base_measure( + expr: Any, + description: str | None, + requires_unnest: tuple, + metadata: Mapping[str, Any] | None = None, +) -> Measure: + """Create a base measure with proper callable wrapping using functional patterns.""" + + @curry + def apply_aggregation(operation: str, column): + """Apply aggregation operation to a column using functional dispatch.""" + operations = { + "sum": lambda c: c.sum(), + "mean": lambda c: c.mean(), + "avg": lambda c: c.mean(), + "count": lambda c: c.count(), + "min": lambda c: c.min(), + "max": lambda c: c.max(), + } + + return ( + Maybe.from_optional(operations.get(operation)) + .map(lambda fn: fn(column)) + .value_or( + (_ for _ in ()).throw(ValueError(f"Unknown aggregation operation: {operation}")) + ) + ) + + @curry + def evaluate_expr(expr, scope): + """Evaluate expression in given scope.""" + return ( + expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr + ) + + def convert_aggregation_expr(t, agg_expr: AggregationExpr): + """Convert AggregationExpr to ibis expression.""" + if agg_expr.operation == "count": + result = t.count() + else: + result = apply_aggregation(agg_expr.operation, t[agg_expr.column]) + + for method_name, args, kwargs_tuple in agg_expr.post_ops: + result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) + + return result + + raw_expr = expr._fn if isinstance(expr, _CallableWrapper) else expr + + if isinstance(expr, AggregationExpr): + + def wrapped_expr(t): + """Convert AggregationExpr to ibis expression.""" + return convert_aggregation_expr(t, expr) + + return Measure( + expr=wrapped_expr, + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) + + if callable(expr): + + def wrapped_expr(t): + """Wrapped expression that handles AggregationExpr conversion.""" + scope = ColumnScope(_tbl=t) + result = evaluate_expr(expr, scope) + + if 
isinstance(result, AggregationExpr): + return convert_aggregation_expr(t, result) + return result + + return Measure( + expr=wrapped_expr, + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) + else: + return Measure( + expr=lambda t, fn=expr: evaluate_expr(fn, ColumnScope(_tbl=t)), + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) + + +def _classify_measure( + fn_or_expr: Any, scope: Any, measure_name: str | None = None +) -> tuple[str, Any]: + """Classify measure as 'calc' or 'base' with appropriate handling.""" + from ..measure_scope import validate_calc_ast + from ._core import _resolve_expr + + expr, description, requires_unnest, metadata = _extract_measure_metadata(fn_or_expr) + + resolved = safe(lambda: _resolve_expr(expr, scope))().map( + lambda val: ("calc", val) if _is_calculated_measure(val) else None + ) + + if isinstance(resolved, Success) and resolved.unwrap() is not None: + kind, value = resolved.unwrap() + validate_calc_ast(value, measure_name) + return (kind, value) + + if not requires_unnest and callable(expr): + # All scopes (MeasureScope, ColumnScope) have tbl attribute + table = scope.tbl + inferred_unnest = _infer_unnest(expr, table) + requires_unnest = requires_unnest or inferred_unnest + + return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) + + +def _build_json_definition( + dims_dict: dict, + meas_dict: dict, + name: str | None = None, + description: str | None = None, +) -> dict: + result = { + "dimensions": {n: spec.to_json() for n, spec in dims_dict.items()}, + "measures": {n: spec.to_json() for n, spec in meas_dict.items()}, + "entity_dimensions": {n: spec.to_json() for n, spec in dims_dict.items() if spec.is_entity}, + "event_timestamp": { + n: spec.to_json() for n, spec in dims_dict.items() if spec.is_event_timestamp + }, + "time_dimensions": { + n: spec.to_json() for n, spec in dims_dict.items() if spec.is_time_dimension + }, + "name": name, + } + if description is not None: + result["description"] = description + return result + + +def _format_column_error(e: AttributeError, table: ir.Table) -> str: + """Format a helpful error message for missing column errors.""" + # Extract the column name from the error + match = re.search(r"has no attribute ['\"]([^'\"]+)['\"]", str(e)) + missing_col = match.group(1) if match else "unknown" + + # Get available columns + available_cols = list(table.columns) if hasattr(table, "columns") else [] + + # Build error message + parts = [f"Dimension expression references non-existent column '{missing_col}'."] + + if len(available_cols) > 20: + parts.append(f"Table has {len(available_cols)} columns. First 15: {available_cols[:15]}") + elif available_cols: + parts.append(f"Available columns: {available_cols}") + else: + parts.append(f"No columns available in {type(table).__name__} object") + + # Suggest similar column names + suggestions = get_close_matches(missing_col, available_cols, n=3, cutoff=0.6) + if suggestions: + parts[-1] += f". Did you mean: {suggestions}?" + + # Add helpful tip + example = suggestions[0] if suggestions else "column_name" + parts.append( + f"\n\nTip: Check that your dimension expression uses the correct column name. 
" + f"For example: lambda t: t.{example}" + ) + + return " ".join(parts) diff --git a/src/boring_semantic_layer/ops/_values.py b/src/boring_semantic_layer/ops/_values.py index b06735a..74fcb12 100644 --- a/src/boring_semantic_layer/ops/_values.py +++ b/src/boring_semantic_layer/ops/_values.py @@ -104,7 +104,7 @@ def __call__(self, table: ir.Table, _dims: dict | None = None) -> ir.Value: if "'Table' object has no attribute" in str( e ) or "'Join' object has no attribute" in str(e): - from ._core import _format_column_error + from ._measure_helpers import _format_column_error raise AttributeError(_format_column_error(e, table)) from e raise From a18105fcbdcf4f665194e3ccba19d9d9c7b00db3 Mon Sep 17 00:00:00 2001 From: Hussain Sultan Date: Mon, 4 May 2026 07:54:50 -0400 Subject: [PATCH 6/6] refactor(ops): extract Table/Filter/Project/GroupBy ops New _basic_ops.py (~430 lines) holds the four foundation Op classes: - SemanticTableOp: leaf relation that carries dimension/measure metadata. Uses lazy imports for _make_schema, _mutate_dimensions_with_dependencies, and the module logger from _core. - SemanticFilterOp: predicate filtering with derived-dimension enrichment. Lazy-imports _Resolver/_get_merged_fields/_resolve_expr/ _to_untagged/_unwrap and SemanticAggregateOp for its post-agg check. - SemanticProjectOp + helpers _classify_fields, _process_nested_access_marker, _evaluate_measures_with_unnesting, _build_select_or_aggregate. - SemanticGroupByOp: thin pass-through. All call sites that used to live in _core re-import from _basic_ops so external `from boring_semantic_layer.ops import X` keeps working. _core.py is now 2,908 lines (down from 5,330 at start of branch). 930 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/boring_semantic_layer/ops/_basic_ops.py | 436 ++++++++++++++++++++ src/boring_semantic_layer/ops/_core.py | 395 +----------------- 2 files changed, 446 insertions(+), 385 deletions(-) create mode 100644 src/boring_semantic_layer/ops/_basic_ops.py diff --git a/src/boring_semantic_layer/ops/_basic_ops.py b/src/boring_semantic_layer/ops/_basic_ops.py new file mode 100644 index 0000000..a4d6d22 --- /dev/null +++ b/src/boring_semantic_layer/ops/_basic_ops.py @@ -0,0 +1,436 @@ +"""Foundation Op classes: Table, Filter, Project, GroupBy. + +These four are the simpler, mostly-pass-through ops. SemanticTableOp is +the leaf that holds dimension/measure metadata; the others wrap it +(or a downstream op) and contribute one operation each. The aggregate, +join, and order/limit ops live in their own modules. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Mapping +from typing import Any + +from attrs import field +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation +from ibis.expr.schema import Schema + +from .._xorq import FrozenDict, FrozenOrderedDict, Schema as XorqSchema +from ..nested_access import NestedAccessMarker +from ._callable import _ensure_wrapped +from ._format import _semantic_repr +from ._measure_helpers import _build_json_definition +from ._root_models import _find_all_root_models +from ._values import Dimension, Measure + + +_SchemaClass = XorqSchema +_FrozenOrderedDict = FrozenOrderedDict + + +class SemanticTableOp(Relation): + """Relation with semantic metadata (dimensions and measures). + + Stores ir.Table expression directly to avoid .op() → .to_expr() conversions. + + Note: Accepts both regular ibis.Table and xorq's vendored ibis.Table. 
+    Tables are stored as-is; conversion to xorq (when supported) happens
+    upstream at SemanticModel construction, not in __init__.
+    """
+
+    table: Any  # Accepts both ir.Table and regular ibis.expr.types.Table
+    dimensions: FrozenDict[str, Dimension]
+    measures: FrozenDict[str, Measure]
+    calc_measures: FrozenDict[str, Any]
+    name: str | None = None
+    description: str | None = None
+    _source_join: Any = field(
+        default=None, repr=False
+    )  # Track if this wraps a join (SemanticJoinOp) for optimization
+
+    def __init__(
+        self,
+        table: ir.Table,
+        dimensions: dict[str, Dimension] | FrozenDict[str, Dimension],
+        measures: dict[str, Measure] | FrozenDict[str, Measure],
+        calc_measures: dict[str, Any] | FrozenDict[str, Any],
+        name: str | None = None,
+        description: str | None = None,
+        _source_join: Any = None,
+    ) -> None:
+        # Accept both regular ibis and xorq tables without conversion
+        # This allows using regular ibis by default, xorq only when provided
+        super().__init__(
+            table=table,
+            dimensions=FrozenDict(dimensions)
+            if not isinstance(dimensions, FrozenDict)
+            else dimensions,
+            measures=FrozenDict(measures) if not isinstance(measures, FrozenDict) else measures,
+            calc_measures=FrozenDict(calc_measures)
+            if not isinstance(calc_measures, FrozenDict)
+            else calc_measures,
+            name=name,
+            description=description,
+            _source_join=_source_join,
+        )
+
+    def __repr__(self) -> str:
+        return _semantic_repr(self)
+
+    @property
+    def values(self) -> FrozenOrderedDict[str, Any]:
+        from ._core import _mutate_dimensions_with_dependencies, logger
+
+        dims = self.get_dimensions()
+        measures = self.get_measures()
+        calc_measures = self.get_calculated_measures()
+        # Build enriched table with all dimensions resolved (handles derived deps)
+        enriched = _mutate_dimensions_with_dependencies(self.table, dims.keys(), dims)
+        base_values = {
+            **{col: self.table[col].op() for col in self.table.columns},
+            **{name: enriched[name].op() for name in dims},
+            **{name: fn(enriched).op() for name, fn in measures.items()},
+        }
+        # Resolve calculated measure types via a dummy table with base measure dtypes.
+        # ``infer_calc_dtype`` mirrors the AggregationExpr rewrite from
+        # ``compile_grouped_with_all`` so calc measures with inline aggregations
+        # (e.g. ``AllOf(AggregationExpr)``) round-trip through type inference.
+        if calc_measures:
+            from ..compile_all import _get_ibis_module, infer_calc_dtype
+
+            measure_schema = {
+                name: base_values[name].dtype for name in measures if name in base_values
+            }
+            ibis_module = _get_ibis_module(enriched)
+            for name, expr in calc_measures.items():
+                try:
+                    compiled = infer_calc_dtype(
+                        expr, measure_schema, enriched, ibis_module
+                    )
+                    base_values[name] = compiled.op()
+                except Exception as e:
+                    # Joined models with dotted column names, calc measures
+                    # whose inline aggregations don't apply to the dummy schema,
+                    # etc. Type info is best-effort; surface for debugging.
+                    logger.debug(
+                        "calc-measure type inference failed for %r: %s", name, e
+                    )
+        return FrozenOrderedDict(base_values)
+
+    @property
+    def schema(self):
+        from ._core import _make_schema
+
+        fields_dict = {name: str(v.dtype) for name, v in self.values.items()}
+        return _make_schema(fields_dict)
+
+    @property
+    def json_definition(self) -> Mapping[str, Any]:
+        return _build_json_definition(
+            self.get_dimensions(),
+            self.get_measures(),
+            self.name,
+            self.description,
+        )
+
+    @property
+    def _dims(self) -> dict[str, Dimension]:
+        return dict(self.get_dimensions())
+
+    @property
+    def _base_measures(self) -> dict[str, Measure]:
+        return dict(self.get_measures())
+
+    @property
+    def _calc_measures(self) -> dict[str, Any]:
+        return dict(self.get_calculated_measures())
+
+    def get_measures(self) -> Mapping[str, Measure]:
+        """Get dictionary of base measures with metadata."""
+        return object.__getattribute__(self, "measures")
+
+    def get_dimensions(self) -> Mapping[str, Dimension]:
+        """Get dictionary of dimensions with metadata."""
+        return object.__getattribute__(self, "dimensions")
+
+    def get_calculated_measures(self) -> Mapping[str, Any]:
+        """Get dictionary of calculated measures with metadata."""
+        return self.calc_measures
+
+    def get_graph(self) -> dict[str, dict[str, Any]]:
+        from ..graph_utils import build_dependency_graph
+
+        return build_dependency_graph(
+            self.get_dimensions(),
+            self.get_measures(),
+            self.get_calculated_measures(),
+            self.table,
+        )
+
+    def __getattribute__(self, name: str):
+        """Override attribute access to return tuples for dimensions/measures.
+
+        This provides a cleaner API where .dimensions returns ('dim1', 'dim2')
+        instead of the full FrozenDict. Use get_dimensions() to get the full dict.
+        """
+        # For special/internal attributes (dunder methods), use default behavior
+        # This is critical for xorq's vendored ibis which uses __precomputed_hash__, etc.
+        if name.startswith("__") and name.endswith("__"):
+            return object.__getattribute__(self, name)
+
+        # Custom behavior for dimensions and measures
+        if name == "dimensions":
+            dims = object.__getattribute__(self, "dimensions")
+            return tuple(dims.keys())
+        if name == "measures":
+            base_meas = object.__getattribute__(self, "measures")
+            calc_meas = object.__getattribute__(self, "calc_measures")
+            return tuple(base_meas.keys()) + tuple(calc_meas.keys())
+
+        # Default behavior for everything else
+        return object.__getattribute__(self, name)
+
+    def to_untagged(self):
+        # Conversion happens at SemanticModel construction; self.table is
+        # already xorq when supported, plain ibis when not.
+        return self.table
+
+
+class SemanticFilterOp(Relation):
+    source: Relation
+    predicate: Callable
+
+    def __init__(self, source: Relation, predicate: Callable) -> None:
+        super().__init__(
+            source=Relation.__coerce__(source),
+            predicate=_ensure_wrapped(predicate),
+        )
+
+    def __repr__(self) -> str:
+        return _semantic_repr(self)
+
+    @property
+    def values(self) -> FrozenOrderedDict[str, Any]:
+        return self.source.values
+
+    @property
+    def schema(self) -> Schema:
+        return self.source.schema
+
+    def to_untagged(self):
+        from ..convert import _Resolver
+        from ._core import (
+            _get_merged_fields,
+            _mutate_dimensions_with_dependencies,
+            _resolve_expr,
+            _to_untagged,
+            _unwrap,
+        )
+
+        # Post-aggregation sources (SemanticAggregateOp) expose neither
+        # dimensions nor an enrichable base table, so an empty dim_map is
+        # the right behavior for them.
+ from ._core import SemanticAggregateOp + + all_roots = _find_all_root_models(self.source) + base_tbl = _to_untagged(self.source) + dim_map = ( + {} + if isinstance(self.source, SemanticAggregateOp) + else _get_merged_fields(all_roots, "dimensions") + ) + + # Enrich table with derived dimensions so multi-level deps + # (e.g. d_two -> d_one -> distance) resolve correctly in filters. + # Best-effort: skip dimensions whose columns aren't available yet + # (e.g. join-based dims); those resolve through the Resolver fallback. + enriched = base_tbl + for dim_name in dim_map: + try: + enriched = _mutate_dimensions_with_dependencies( + enriched, [dim_name], dim_map + ) + except (TypeError, KeyError, AttributeError): + pass + + pred_fn = _unwrap(self.predicate) + resolver = _Resolver(enriched, dim_map) + pred = _resolve_expr(pred_fn, resolver) + return enriched.filter(pred) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() + + +def _classify_fields( + fields: tuple[str, ...], + dimensions: dict, + measures: dict, +) -> tuple[list[str], list[str], list[str]]: + """Classify fields into dimensions, measures, and raw columns.""" + dims = [f for f in fields if f in dimensions] + meas = [f for f in fields if f in measures] + raw = [f for f in fields if f not in dimensions and f not in measures] + return dims, meas, raw + + +def _process_nested_access_marker( + marker: NestedAccessMarker, + name: str, + tbl: ir.Table, +) -> tuple[ir.Table, ir.Value]: + """Process a NestedAccessMarker to unnest and build aggregation expression.""" + unnested = tbl + for array_col in marker.array_path: + if array_col in unnested.columns: + unnested = unnested.unnest(array_col) + + if marker.operation == "count": + return unnested, unnested.count().name(name) + + expr = getattr(unnested, marker.array_path[0]) + for field_name in marker.field_path: + expr = getattr(expr, field_name) + + if marker.operation in ("sum", "mean", "min", "max", "nunique"): + agg_fn = getattr(expr, marker.operation) + return unnested, agg_fn().name(name) + + raise ValueError(f"Unknown operation: {marker.operation}") + + +def _evaluate_measures_with_unnesting( + measure_names: list[str], + measures: dict, + tbl: ir.Table, +) -> dict: + """Evaluate measures and apply automatic unnesting if needed. 
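+
+    Measures may return a ``NestedAccessMarker`` rather than a plain ibis
+    expression; those markers trigger an unnest of the referenced array
+    column(s) before the aggregation expression is built.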
+ + Returns dict with: + - table: potentially unnested table + - measure_exprs: list of evaluated measure expressions + - needs_unnesting: whether unnesting occurred + """ + meas_exprs = [] + current_tbl = tbl + needs_unnesting = False + + for name in measure_names: + result = measures[name](tbl) + + if isinstance(result, NestedAccessMarker): + current_tbl, meas_expr = _process_nested_access_marker(result, name, current_tbl) + meas_exprs.append(meas_expr) + needs_unnesting = True + else: + meas_exprs.append(result.name(name)) + + return { + "table": current_tbl, + "measure_exprs": meas_exprs, + "needs_unnesting": needs_unnesting, + } + + +def _build_select_or_aggregate( + tbl: ir.Table, + dim_exprs: list, + meas_exprs: list, + raw_exprs: list, +) -> ir.Table: + """Build appropriate select/aggregate based on what expressions exist.""" + if meas_exprs and dim_exprs: + return tbl.group_by(dim_exprs).aggregate(meas_exprs) + if meas_exprs: + return tbl.aggregate(meas_exprs) + if dim_exprs or raw_exprs: + return tbl.select(dim_exprs + raw_exprs) + return tbl + + +class SemanticProjectOp(Relation): + source: Relation + fields: tuple[str, ...] + + def __init__(self, source: Relation, fields: Iterable[str]) -> None: + super().__init__(source=Relation.__coerce__(source), fields=tuple(fields)) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + src_vals = self.source.values + return FrozenOrderedDict( + {k: v for k, v in src_vals.items() if k in self.fields}, + ) + + @property + def schema(self) -> Schema: + return _SchemaClass(fields=_FrozenOrderedDict({k: v.dtype for k, v in self.values.items()})) + + def to_untagged(self): + from ._core import _get_merged_fields, _to_untagged + + all_roots = _find_all_root_models(self.source) + tbl = _to_untagged(self.source) + + if not all_roots: + return tbl.select([getattr(tbl, f) for f in self.fields]) + + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + merged_measures = _get_merged_fields(all_roots, "measures") + + dims, meas, raw_fields = _classify_fields(self.fields, merged_dimensions, merged_measures) + + # Evaluate measures and handle automatic unnesting + meas_result = _evaluate_measures_with_unnesting(meas, merged_measures, tbl) + + active_tbl = meas_result["table"] + meas_exprs = meas_result["measure_exprs"] + needs_unnesting = meas_result["needs_unnesting"] + + # Re-evaluate dimensions on unnested table if needed + dim_exprs = ( + [merged_dimensions[name](active_tbl).name(name) for name in dims] + if needs_unnesting + else [merged_dimensions[name](tbl).name(name) for name in dims] + ) + + # Get raw columns that still exist after unnesting + raw_exprs = [getattr(active_tbl, name) for name in raw_fields if name in active_tbl.columns] + + return _build_select_or_aggregate(active_tbl, dim_exprs, meas_exprs, raw_exprs) + + +class SemanticGroupByOp(Relation): + source: Relation + keys: tuple[str, ...] 
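+    # Marker op: only records the group-by keys; values, schema, and
+    # to_untagged all defer to the source relation.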
+ + def __init__(self, source: Relation, keys: Iterable[str]) -> None: + super().__init__(source=Relation.__coerce__(source), keys=tuple(keys)) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _to_untagged + + return _to_untagged(self.source) diff --git a/src/boring_semantic_layer/ops/_core.py b/src/boring_semantic_layer/ops/_core.py index 654922c..7a4a011 100644 --- a/src/boring_semantic_layer/ops/_core.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -301,391 +301,16 @@ def _classify_dependencies( -class SemanticTableOp(Relation): - """Relation with semantic metadata (dimensions and measures). - - Stores ir.Table expression directly to avoid .op() → .to_expr() conversions. - - Note: Accepts both regular ibis.Table and xorq's vendored ibis.Table. - Regular ibis tables are automatically converted to xorq in __init__. - """ - - table: Any # Accepts both ir.Table and regular ibis.expr.types.Table - dimensions: FrozenDict[str, Dimension] - measures: FrozenDict[str, Measure] - calc_measures: FrozenDict[str, Any] - name: str | None = None - description: str | None = None - _source_join: Any = field( - default=None, repr=False - ) # Track if this wraps a join (SemanticJoinOp) for optimization - - def __init__( - self, - table: ir.Table, - dimensions: dict[str, Dimension] | FrozenDict[str, Dimension], - measures: dict[str, Measure] | FrozenDict[str, Measure], - calc_measures: dict[str, Any] | FrozenDict[str, Any], - name: str | None = None, - description: str | None = None, - _source_join: Any = None, - ) -> None: - # Accept both regular ibis and xorq tables without conversion - # This allows using regular ibis by default, xorq only when provided - super().__init__( - table=table, - dimensions=FrozenDict(dimensions) - if not isinstance(dimensions, FrozenDict) - else dimensions, - measures=FrozenDict(measures) if not isinstance(measures, FrozenDict) else measures, - calc_measures=FrozenDict(calc_measures) - if not isinstance(calc_measures, FrozenDict) - else calc_measures, - name=name, - description=description, - _source_join=_source_join, - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - dims = self.get_dimensions() - measures = self.get_measures() - calc_measures = self.get_calculated_measures() - # Build enriched table with all dimensions resolved (handles derived deps) - enriched = _mutate_dimensions_with_dependencies(self.table, dims.keys(), dims) - base_values = { - **{col: self.table[col].op() for col in self.table.columns}, - **{name: enriched[name].op() for name in dims}, - **{name: fn(enriched).op() for name, fn in measures.items()}, - } - # Resolve calculated measure types via a dummy table with base measure dtypes. - # ``infer_calc_dtype`` mirrors the AggregationExpr rewrite from - # ``compile_grouped_with_all`` so calc measures with inline aggregations - # (e.g. ``AllOf(AggregationExpr)``) round-trip through type inference. 
- if calc_measures: - from ..compile_all import _get_ibis_module, infer_calc_dtype - - measure_schema = { - name: base_values[name].dtype for name in measures if name in base_values - } - ibis_module = _get_ibis_module(enriched) - for name, expr in calc_measures.items(): - try: - compiled = infer_calc_dtype( - expr, measure_schema, enriched, ibis_module - ) - base_values[name] = compiled.op() - except Exception as e: - # Joined models with dotted column names, calc measures - # whose inline aggregations don't apply to the dummy schema, - # etc. Type info is best-effort; surface for debugging. - logger.debug( - "calc-measure type inference failed for %r: %s", name, e - ) - return FrozenOrderedDict(base_values) - - @property - def schema(self): - fields_dict = {name: str(v.dtype) for name, v in self.values.items()} - return _make_schema(fields_dict) - - @property - def json_definition(self) -> Mapping[str, Any]: - return _build_json_definition( - self.get_dimensions(), - self.get_measures(), - self.name, - self.description, - ) - - @property - def _dims(self) -> dict[str, Dimension]: - return dict(self.get_dimensions()) - - @property - def _base_measures(self) -> dict[str, Measure]: - return dict(self.get_measures()) - - @property - def _calc_measures(self) -> dict[str, Any]: - return dict(self.get_calculated_measures()) - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of base measures with metadata.""" - return object.__getattribute__(self, "measures") - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions with metadata.""" - return object.__getattribute__(self, "dimensions") - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures with metadata.""" - return self.calc_measures - - def get_graph(self) -> dict[str, dict[str, Any]]: - from ..graph_utils import build_dependency_graph - - return build_dependency_graph( - self.get_dimensions(), - self.get_measures(), - self.get_calculated_measures(), - self.table, - ) - - def __getattribute__(self, name: str): - """Override attribute access to return tuples for dimensions/measures. - - This provides a cleaner API where .dimensions returns ('dim1', 'dim2') - instead of the full FrozenDict. Use get_dimensions() to get the full dict. - """ - # For special/internal attributes (dunder methods), use default behavior - # This is critical for xorq's vendored ibis which uses __precomputed_hash__, etc. - if name.startswith("__") and name.endswith("__"): - return object.__getattribute__(self, name) - - # Custom behavior for dimensions and measures - if name == "dimensions": - dims = object.__getattribute__(self, "dimensions") - return tuple(dims.keys()) - if name == "measures": - base_meas = object.__getattribute__(self, "measures") - calc_meas = object.__getattribute__(self, "calc_measures") - return tuple(base_meas.keys()) + tuple(calc_meas.keys()) - - # Default behavior for everything else - return object.__getattribute__(self, name) - - def to_untagged(self): - # Conversion happens at SemanticModel construction; self.table is - # already xorq when supported, plain ibis when not. 
- return self.table - - -class SemanticFilterOp(Relation): - source: Relation - predicate: Callable - - def __init__(self, source: Relation, predicate: Callable) -> None: - super().__init__( - source=Relation.__coerce__(source), - predicate=_ensure_wrapped(predicate), - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - - @property - def schema(self) -> Schema: - return self.source.schema - - def to_untagged(self): - from ..convert import _Resolver - - all_roots = _find_all_root_models(self.source) - base_tbl = _to_untagged(self.source) - dim_map = ( - {} - if isinstance(self.source, SemanticAggregateOp) - else _get_merged_fields(all_roots, "dimensions") - ) - - # Enrich table with derived dimensions so multi-level deps - # (e.g. d_two -> d_one -> distance) resolve correctly in filters. - # Best-effort: skip dimensions whose columns aren't available yet - # (e.g. join-based dims); those resolve through the Resolver fallback. - enriched = base_tbl - for dim_name in dim_map: - try: - enriched = _mutate_dimensions_with_dependencies( - enriched, [dim_name], dim_map - ) - except (TypeError, KeyError, AttributeError): - pass - - pred_fn = _unwrap(self.predicate) - resolver = _Resolver(enriched, dim_map) - pred = _resolve_expr(pred_fn, resolver) - return enriched.filter(pred) - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() - - -def _classify_fields( - fields: tuple[str, ...], - dimensions: dict, - measures: dict, -) -> tuple[list[str], list[str], list[str]]: - """Classify fields into dimensions, measures, and raw columns.""" - dims = [f for f in fields if f in dimensions] - meas = [f for f in fields if f in measures] - raw = [f for f in fields if f not in dimensions and f not in measures] - return dims, meas, raw - - -def _process_nested_access_marker( - marker: NestedAccessMarker, - name: str, - tbl: ir.Table, -) -> tuple[ir.Table, ir.Value]: - """Process a NestedAccessMarker to unnest and build aggregation expression.""" - unnested = tbl - for array_col in marker.array_path: - if array_col in unnested.columns: - unnested = unnested.unnest(array_col) - - if marker.operation == "count": - return unnested, unnested.count().name(name) - - expr = getattr(unnested, marker.array_path[0]) - for field_name in marker.field_path: - expr = getattr(expr, field_name) - - if marker.operation in ("sum", "mean", "min", "max", "nunique"): - agg_fn = getattr(expr, marker.operation) - return unnested, agg_fn().name(name) - - raise ValueError(f"Unknown operation: {marker.operation}") - - -def _evaluate_measures_with_unnesting( - measure_names: list[str], - measures: dict, - tbl: ir.Table, -) -> dict: - """Evaluate measures and apply automatic unnesting if needed. 
- - Returns dict with: - - table: potentially unnested table - - measure_exprs: list of evaluated measure expressions - - needs_unnesting: whether unnesting occurred - """ - meas_exprs = [] - current_tbl = tbl - needs_unnesting = False - - for name in measure_names: - result = measures[name](tbl) - - if isinstance(result, NestedAccessMarker): - current_tbl, meas_expr = _process_nested_access_marker(result, name, current_tbl) - meas_exprs.append(meas_expr) - needs_unnesting = True - else: - meas_exprs.append(result.name(name)) - - return { - "table": current_tbl, - "measure_exprs": meas_exprs, - "needs_unnesting": needs_unnesting, - } - - -def _build_select_or_aggregate( - tbl: ir.Table, - dim_exprs: list, - meas_exprs: list, - raw_exprs: list, -) -> ir.Table: - """Build appropriate select/aggregate based on what expressions exist.""" - if meas_exprs and dim_exprs: - return tbl.group_by(dim_exprs).aggregate(meas_exprs) - if meas_exprs: - return tbl.aggregate(meas_exprs) - if dim_exprs or raw_exprs: - return tbl.select(dim_exprs + raw_exprs) - return tbl - - -class SemanticProjectOp(Relation): - source: Relation - fields: tuple[str, ...] - - def __init__(self, source: Relation, fields: Iterable[str]) -> None: - super().__init__(source=Relation.__coerce__(source), fields=tuple(fields)) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - src_vals = self.source.values - return FrozenOrderedDict( - {k: v for k, v in src_vals.items() if k in self.fields}, - ) - - @property - def schema(self) -> Schema: - return _SchemaClass(fields=_FrozenOrderedDict({k: v.dtype for k, v in self.values.items()})) - - def to_untagged(self): - all_roots = _find_all_root_models(self.source) - tbl = _to_untagged(self.source) - - if not all_roots: - return tbl.select([getattr(tbl, f) for f in self.fields]) - - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - merged_measures = _get_merged_fields(all_roots, "measures") - - dims, meas, raw_fields = _classify_fields(self.fields, merged_dimensions, merged_measures) - - # Evaluate measures and handle automatic unnesting - meas_result = _evaluate_measures_with_unnesting(meas, merged_measures, tbl) - - active_tbl = meas_result["table"] - meas_exprs = meas_result["measure_exprs"] - needs_unnesting = meas_result["needs_unnesting"] - - # Re-evaluate dimensions on unnested table if needed - dim_exprs = ( - [merged_dimensions[name](active_tbl).name(name) for name in dims] - if needs_unnesting - else [merged_dimensions[name](tbl).name(name) for name in dims] - ) - - # Get raw columns that still exist after unnesting - raw_exprs = [getattr(active_tbl, name) for name in raw_fields if name in active_tbl.columns] - - return _build_select_or_aggregate(active_tbl, dim_exprs, meas_exprs, raw_exprs) - - -class SemanticGroupByOp(Relation): - source: Relation - keys: tuple[str, ...] 
- - def __init__(self, source: Relation, keys: Iterable[str]) -> None: - super().__init__(source=Relation.__coerce__(source), keys=tuple(keys)) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - - @property - def schema(self) -> Schema: - return self.source.schema - - def to_untagged(self): - return _to_untagged(self.source) +from ._basic_ops import ( # noqa: E402, F401 + SemanticFilterOp, + SemanticGroupByOp, + SemanticProjectOp, + SemanticTableOp, + _build_select_or_aggregate, + _classify_fields, + _evaluate_measures_with_unnesting, + _process_nested_access_marker, +) @frozen