diff --git a/src/boring_semantic_layer/ops/__init__.py b/src/boring_semantic_layer/ops/__init__.py new file mode 100644 index 0000000..11b4158 --- /dev/null +++ b/src/boring_semantic_layer/ops/__init__.py @@ -0,0 +1,67 @@ +"""Semantic layer operations. + +This package was split out of a single ``ops.py``. ``__init__`` re-exports +the names that callers import from ``boring_semantic_layer.ops`` so existing +imports keep working unchanged. +""" + +from ._column_extraction import ( + ColumnExtractionResult, + ColumnTracker, + JoinColumnExtractionResult, + TableColumnRequirements, + _extract_columns_from_callable, + _extract_join_key_columns, + _extract_requirements_from_keys, + _extract_requirements_from_measures, + _make_tracking_proxy, + _parse_prefixed_field, +) +from ._core import ( + # Public value objects + Dimension, + Measure, + # Public Op classes + SemanticAggregateOp, + SemanticFilterOp, + SemanticGroupByOp, + SemanticIndexOp, + SemanticJoinOp, + SemanticLimitOp, + SemanticMutateOp, + SemanticOrderByOp, + SemanticProjectOp, + SemanticTableOp, + SemanticUnnestOp, + # Private helpers used by other modules in this package + _CallableWrapper, + _classify_measure, + _collect_measure_refs, + _ensure_xorq_table, + _find_all_root_models, + _get_field_dict, + _get_merged_fields, + _is_deferred, + _make_schema, + _merge_fields_with_prefixing, + _normalize_join_predicate, + _normalize_to_name, + _rebind_to_backend, + _rebind_to_canonical_backend, +) + +__all__ = [ + "Dimension", + "Measure", + "SemanticAggregateOp", + "SemanticFilterOp", + "SemanticGroupByOp", + "SemanticIndexOp", + "SemanticJoinOp", + "SemanticLimitOp", + "SemanticMutateOp", + "SemanticOrderByOp", + "SemanticProjectOp", + "SemanticTableOp", + "SemanticUnnestOp", +] diff --git a/src/boring_semantic_layer/ops/_basic_ops.py b/src/boring_semantic_layer/ops/_basic_ops.py new file mode 100644 index 0000000..a4d6d22 --- /dev/null +++ b/src/boring_semantic_layer/ops/_basic_ops.py @@ -0,0 +1,436 @@ +"""Foundation Op classes: Table, Filter, Project, GroupBy. + +These four are the simpler, mostly-pass-through ops. SemanticTableOp is +the leaf that holds dimension/measure metadata; the others wrap it +(or a downstream op) and contribute one operation each. The aggregate, +join, and order/limit ops live in their own modules. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Mapping +from typing import Any + +from attrs import field +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation +from ibis.expr.schema import Schema + +from .._xorq import FrozenDict, FrozenOrderedDict, Schema as XorqSchema +from ..nested_access import NestedAccessMarker +from ._callable import _ensure_wrapped +from ._format import _semantic_repr +from ._measure_helpers import _build_json_definition +from ._root_models import _find_all_root_models +from ._values import Dimension, Measure + + +_SchemaClass = XorqSchema +_FrozenOrderedDict = FrozenOrderedDict + + +class SemanticTableOp(Relation): + """Relation with semantic metadata (dimensions and measures). + + Stores ir.Table expression directly to avoid .op() → .to_expr() conversions. + + Note: Accepts both regular ibis.Table and xorq's vendored ibis.Table. + Regular ibis tables are automatically converted to xorq in __init__. + """ + + table: Any # Accepts both ir.Table and regular ibis.expr.types.Table + dimensions: FrozenDict[str, Dimension] + measures: FrozenDict[str, Measure] + calc_measures: FrozenDict[str, Any] + name: str | None = None + description: str | None = None + _source_join: Any = field( + default=None, repr=False + ) # Track if this wraps a join (SemanticJoinOp) for optimization + + def __init__( + self, + table: ir.Table, + dimensions: dict[str, Dimension] | FrozenDict[str, Dimension], + measures: dict[str, Measure] | FrozenDict[str, Measure], + calc_measures: dict[str, Any] | FrozenDict[str, Any], + name: str | None = None, + description: str | None = None, + _source_join: Any = None, + ) -> None: + # Accept both regular ibis and xorq tables without conversion + # This allows using regular ibis by default, xorq only when provided + super().__init__( + table=table, + dimensions=FrozenDict(dimensions) + if not isinstance(dimensions, FrozenDict) + else dimensions, + measures=FrozenDict(measures) if not isinstance(measures, FrozenDict) else measures, + calc_measures=FrozenDict(calc_measures) + if not isinstance(calc_measures, FrozenDict) + else calc_measures, + name=name, + description=description, + _source_join=_source_join, + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + from ._core import _make_schema, _mutate_dimensions_with_dependencies, logger + + dims = self.get_dimensions() + measures = self.get_measures() + calc_measures = self.get_calculated_measures() + # Build enriched table with all dimensions resolved (handles derived deps) + enriched = _mutate_dimensions_with_dependencies(self.table, dims.keys(), dims) + base_values = { + **{col: self.table[col].op() for col in self.table.columns}, + **{name: enriched[name].op() for name in dims}, + **{name: fn(enriched).op() for name, fn in measures.items()}, + } + # Resolve calculated measure types via a dummy table with base measure dtypes. + # ``infer_calc_dtype`` mirrors the AggregationExpr rewrite from + # ``compile_grouped_with_all`` so calc measures with inline aggregations + # (e.g. ``AllOf(AggregationExpr)``) round-trip through type inference. + if calc_measures: + from ..compile_all import _get_ibis_module, infer_calc_dtype + + measure_schema = { + name: base_values[name].dtype for name in measures if name in base_values + } + ibis_module = _get_ibis_module(enriched) + for name, expr in calc_measures.items(): + try: + compiled = infer_calc_dtype( + expr, measure_schema, enriched, ibis_module + ) + base_values[name] = compiled.op() + except Exception as e: + # Joined models with dotted column names, calc measures + # whose inline aggregations don't apply to the dummy schema, + # etc. Type info is best-effort; surface for debugging. + logger.debug( + "calc-measure type inference failed for %r: %s", name, e + ) + return FrozenOrderedDict(base_values) + + @property + def schema(self): + from ._core import _make_schema + + fields_dict = {name: str(v.dtype) for name, v in self.values.items()} + return _make_schema(fields_dict) + + @property + def json_definition(self) -> Mapping[str, Any]: + return _build_json_definition( + self.get_dimensions(), + self.get_measures(), + self.name, + self.description, + ) + + @property + def _dims(self) -> dict[str, Dimension]: + return dict(self.get_dimensions()) + + @property + def _base_measures(self) -> dict[str, Measure]: + return dict(self.get_measures()) + + @property + def _calc_measures(self) -> dict[str, Any]: + return dict(self.get_calculated_measures()) + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of base measures with metadata.""" + return object.__getattribute__(self, "measures") + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions with metadata.""" + return object.__getattribute__(self, "dimensions") + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures with metadata.""" + return self.calc_measures + + def get_graph(self) -> dict[str, dict[str, Any]]: + from ..graph_utils import build_dependency_graph + + return build_dependency_graph( + self.get_dimensions(), + self.get_measures(), + self.get_calculated_measures(), + self.table, + ) + + def __getattribute__(self, name: str): + """Override attribute access to return tuples for dimensions/measures. + + This provides a cleaner API where .dimensions returns ('dim1', 'dim2') + instead of the full FrozenDict. Use get_dimensions() to get the full dict. + """ + # For special/internal attributes (dunder methods), use default behavior + # This is critical for xorq's vendored ibis which uses __precomputed_hash__, etc. + if name.startswith("__") and name.endswith("__"): + return object.__getattribute__(self, name) + + # Custom behavior for dimensions and measures + if name == "dimensions": + dims = object.__getattribute__(self, "dimensions") + return tuple(dims.keys()) + if name == "measures": + base_meas = object.__getattribute__(self, "measures") + calc_meas = object.__getattribute__(self, "calc_measures") + return tuple(base_meas.keys()) + tuple(calc_meas.keys()) + + # Default behavior for everything else + return object.__getattribute__(self, name) + + def to_untagged(self): + # Conversion happens at SemanticModel construction; self.table is + # already xorq when supported, plain ibis when not. + return self.table + + +class SemanticFilterOp(Relation): + source: Relation + predicate: Callable + + def __init__(self, source: Relation, predicate: Callable) -> None: + super().__init__( + source=Relation.__coerce__(source), + predicate=_ensure_wrapped(predicate), + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ..convert import _Resolver + from ._core import ( + _get_merged_fields, + _mutate_dimensions_with_dependencies, + _resolve_expr, + _to_untagged, + _unwrap, + ) + + # Avoid an isinstance check against SemanticAggregateOp by duck-typing: + # post-aggregation sources expose neither dimensions nor an enrichable + # base table, so an empty dim_map is the right behavior. + from ._core import SemanticAggregateOp + + all_roots = _find_all_root_models(self.source) + base_tbl = _to_untagged(self.source) + dim_map = ( + {} + if isinstance(self.source, SemanticAggregateOp) + else _get_merged_fields(all_roots, "dimensions") + ) + + # Enrich table with derived dimensions so multi-level deps + # (e.g. d_two -> d_one -> distance) resolve correctly in filters. + # Best-effort: skip dimensions whose columns aren't available yet + # (e.g. join-based dims); those resolve through the Resolver fallback. + enriched = base_tbl + for dim_name in dim_map: + try: + enriched = _mutate_dimensions_with_dependencies( + enriched, [dim_name], dim_map + ) + except (TypeError, KeyError, AttributeError): + pass + + pred_fn = _unwrap(self.predicate) + resolver = _Resolver(enriched, dim_map) + pred = _resolve_expr(pred_fn, resolver) + return enriched.filter(pred) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() + + +def _classify_fields( + fields: tuple[str, ...], + dimensions: dict, + measures: dict, +) -> tuple[list[str], list[str], list[str]]: + """Classify fields into dimensions, measures, and raw columns.""" + dims = [f for f in fields if f in dimensions] + meas = [f for f in fields if f in measures] + raw = [f for f in fields if f not in dimensions and f not in measures] + return dims, meas, raw + + +def _process_nested_access_marker( + marker: NestedAccessMarker, + name: str, + tbl: ir.Table, +) -> tuple[ir.Table, ir.Value]: + """Process a NestedAccessMarker to unnest and build aggregation expression.""" + unnested = tbl + for array_col in marker.array_path: + if array_col in unnested.columns: + unnested = unnested.unnest(array_col) + + if marker.operation == "count": + return unnested, unnested.count().name(name) + + expr = getattr(unnested, marker.array_path[0]) + for field_name in marker.field_path: + expr = getattr(expr, field_name) + + if marker.operation in ("sum", "mean", "min", "max", "nunique"): + agg_fn = getattr(expr, marker.operation) + return unnested, agg_fn().name(name) + + raise ValueError(f"Unknown operation: {marker.operation}") + + +def _evaluate_measures_with_unnesting( + measure_names: list[str], + measures: dict, + tbl: ir.Table, +) -> dict: + """Evaluate measures and apply automatic unnesting if needed. + + Returns dict with: + - table: potentially unnested table + - measure_exprs: list of evaluated measure expressions + - needs_unnesting: whether unnesting occurred + """ + meas_exprs = [] + current_tbl = tbl + needs_unnesting = False + + for name in measure_names: + result = measures[name](tbl) + + if isinstance(result, NestedAccessMarker): + current_tbl, meas_expr = _process_nested_access_marker(result, name, current_tbl) + meas_exprs.append(meas_expr) + needs_unnesting = True + else: + meas_exprs.append(result.name(name)) + + return { + "table": current_tbl, + "measure_exprs": meas_exprs, + "needs_unnesting": needs_unnesting, + } + + +def _build_select_or_aggregate( + tbl: ir.Table, + dim_exprs: list, + meas_exprs: list, + raw_exprs: list, +) -> ir.Table: + """Build appropriate select/aggregate based on what expressions exist.""" + if meas_exprs and dim_exprs: + return tbl.group_by(dim_exprs).aggregate(meas_exprs) + if meas_exprs: + return tbl.aggregate(meas_exprs) + if dim_exprs or raw_exprs: + return tbl.select(dim_exprs + raw_exprs) + return tbl + + +class SemanticProjectOp(Relation): + source: Relation + fields: tuple[str, ...] + + def __init__(self, source: Relation, fields: Iterable[str]) -> None: + super().__init__(source=Relation.__coerce__(source), fields=tuple(fields)) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + src_vals = self.source.values + return FrozenOrderedDict( + {k: v for k, v in src_vals.items() if k in self.fields}, + ) + + @property + def schema(self) -> Schema: + return _SchemaClass(fields=_FrozenOrderedDict({k: v.dtype for k, v in self.values.items()})) + + def to_untagged(self): + from ._core import _get_merged_fields, _to_untagged + + all_roots = _find_all_root_models(self.source) + tbl = _to_untagged(self.source) + + if not all_roots: + return tbl.select([getattr(tbl, f) for f in self.fields]) + + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + merged_measures = _get_merged_fields(all_roots, "measures") + + dims, meas, raw_fields = _classify_fields(self.fields, merged_dimensions, merged_measures) + + # Evaluate measures and handle automatic unnesting + meas_result = _evaluate_measures_with_unnesting(meas, merged_measures, tbl) + + active_tbl = meas_result["table"] + meas_exprs = meas_result["measure_exprs"] + needs_unnesting = meas_result["needs_unnesting"] + + # Re-evaluate dimensions on unnested table if needed + dim_exprs = ( + [merged_dimensions[name](active_tbl).name(name) for name in dims] + if needs_unnesting + else [merged_dimensions[name](tbl).name(name) for name in dims] + ) + + # Get raw columns that still exist after unnesting + raw_exprs = [getattr(active_tbl, name) for name in raw_fields if name in active_tbl.columns] + + return _build_select_or_aggregate(active_tbl, dim_exprs, meas_exprs, raw_exprs) + + +class SemanticGroupByOp(Relation): + source: Relation + keys: tuple[str, ...] + + def __init__(self, source: Relation, keys: Iterable[str]) -> None: + super().__init__(source=Relation.__coerce__(source), keys=tuple(keys)) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _to_untagged + + return _to_untagged(self.source) diff --git a/src/boring_semantic_layer/ops/_callable.py b/src/boring_semantic_layer/ops/_callable.py new file mode 100644 index 0000000..41ace49 --- /dev/null +++ b/src/boring_semantic_layer/ops/_callable.py @@ -0,0 +1,60 @@ +"""Hashable callable wrapper used to put lambdas/Deferred into FrozenDict. + +Both raw callables and user-side ``Deferred`` instances are not hashable +in their bare form, but ibis Op classes store dimension/measure exprs in +``FrozenDict``. ``_CallableWrapper`` gives them an identity-based hash so +they can be persisted in op fields. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from attrs import frozen + + +@frozen +class _CallableWrapper: + """Hashable wrapper for Callable and Deferred. + + Both raw callables (lambda) and user Deferred (_.foo) are not hashable + and cannot be stored in FrozenDict. This wrapper provides hashability + using identity-based hashing. + """ + + _fn: Any + + def __call__(self, *args, **kwargs): + return self._fn(*args, **kwargs) + + def __hash__(self): + # should this be dask.base.tokenize()? + return hash(id(self._fn)) + + @property + def unwrap(self): + return self._fn + + +def _ensure_wrapped(fn: Any) -> _CallableWrapper: + """Wrap Callable or Deferred for hashability.""" + return fn if isinstance(fn, _CallableWrapper) else _CallableWrapper(fn) + + +def _infer_unnest(fn: Callable, table: Any) -> tuple[str, ...]: + """Infer required unnest operations from the table. + + Examples: + to_semantic_table(tbl).with_measures(...) -> () # Session level + to_semantic_table(tbl).unnest("hits").with_measures(...) -> ("hits",) + unnested.unnest("product").with_measures(...) -> ("product",) + """ + from ..expr import SemanticUnnest + + if isinstance(table, SemanticUnnest): + op = table.op() + # SemanticUnnestOp always has column attribute + return (op.column,) + + return () diff --git a/src/boring_semantic_layer/ops/_column_extraction.py b/src/boring_semantic_layer/ops/_column_extraction.py new file mode 100644 index 0000000..879aaeb --- /dev/null +++ b/src/boring_semantic_layer/ops/_column_extraction.py @@ -0,0 +1,369 @@ +"""Column tracking & per-table column-requirement extraction. + +Used by projection pushdown to figure out which columns each leaf table +must produce. Split from ``_core.py`` to keep the relational ops module +focused on Op definitions. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Any + +from attrs import field, frozen +from ibis.expr import operations as ibis_ops +from ibis.expr import types as ir + +from .._xorq import FrozenDict +from ..graph_utils import walk_nodes + +logger = logging.getLogger(__name__) + + +# ============================================================================== +# Column Tracking for Projection Pushdown +# ============================================================================== + + +@frozen +class ColumnTracker: + """Immutable tracker for column references during expression evaluation. + + Uses frozenset for tracked columns. New columns are added by creating + new tracker instances with updated sets. + """ + + columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + + def with_column(self, col_name: str) -> ColumnTracker: + """Return new tracker with additional column.""" + return ColumnTracker(columns=self.columns | {col_name}) + + def merge(self, other: ColumnTracker) -> ColumnTracker: + """Return new tracker with merged columns.""" + return ColumnTracker(columns=self.columns | other.columns) + + +@frozen +class ColumnExtractionResult: + """Result of column extraction with error handling. + + Separates successful extraction from error cases. + """ + + columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + extraction_failed: bool = False + error_type: type[Exception] | None = None + + @classmethod + def success(cls, columns: set[str] | frozenset[str]) -> ColumnExtractionResult: + """Create successful result.""" + return cls(columns=frozenset(columns), extraction_failed=False) + + @classmethod + def failure(cls, error: Exception) -> ColumnExtractionResult: + """Create failure result with error information.""" + return cls( + columns=frozenset(), + extraction_failed=True, + error_type=type(error), + ) + + def is_success(self) -> bool: + """Check if extraction succeeded.""" + return not self.extraction_failed + + +@frozen +class JoinColumnExtractionResult: + """Result of join column extraction for both tables.""" + + left_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + right_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + extraction_failed: bool = False + error_type: type[Exception] | None = None + + @classmethod + def success( + cls, + left: set[str] | frozenset[str], + right: set[str] | frozenset[str], + ) -> JoinColumnExtractionResult: + """Create successful result.""" + return cls( + left_columns=frozenset(left), + right_columns=frozenset(right), + extraction_failed=False, + ) + + @classmethod + def failure(cls, error: Exception) -> JoinColumnExtractionResult: + """Create failure result with error information.""" + return cls( + left_columns=frozenset(), + right_columns=frozenset(), + extraction_failed=True, + error_type=type(error), + ) + + def is_success(self) -> bool: + """Check if extraction succeeded.""" + return not self.extraction_failed + + +def _make_tracking_proxy( + table: ir.Table, + on_access: Callable[[str], None], +) -> Any: + """Create tracking proxy with custom access handler. + + Composable factory that enables different tracking strategies + via the on_access callback. + """ + + class _TrackingProxy: + """Proxy that tracks attribute and item access.""" + + def __init__(self, inner_table: ir.Table, access_handler: Callable[[str], None]): + object.__setattr__(self, "_table", inner_table) + object.__setattr__(self, "_on_access", access_handler) + + def __getattr__(self, name: str): + if name.startswith("_"): + return getattr(self._table, name) + self._on_access(name) + return getattr(self._table, name) + + def __getitem__(self, name: str): + self._on_access(name) + return self._table[name] + + return _TrackingProxy(table, on_access) + + +def _extract_columns_from_callable( + fn: Any, + table: ir.Table, +) -> ColumnExtractionResult: + """Extract column names referenced by a callable. + + Uses immutable tracking and returns structured result. + """ + if not callable(fn): + return ColumnExtractionResult.success(frozenset()) + + tracker_ref = [ColumnTracker()] + + def on_column_access(col_name: str) -> None: + tracker_ref[0] = tracker_ref[0].with_column(col_name) + + try: + tracking_proxy = _make_tracking_proxy(table, on_column_access) + fn(tracking_proxy) + return ColumnExtractionResult.success(tracker_ref[0].columns) + + except Exception as e: + return ColumnExtractionResult.failure(e) + + +def _extract_join_key_columns( + on: Callable[[Any, Any], ir.BooleanValue], + left_table: ir.Table, + right_table: ir.Table, +) -> JoinColumnExtractionResult: + left_tracker_ref = [ColumnTracker()] + right_tracker_ref = [ColumnTracker()] + + def on_left_access(col_name: str) -> None: + left_tracker_ref[0] = left_tracker_ref[0].with_column(col_name) + + def on_right_access(col_name: str) -> None: + right_tracker_ref[0] = right_tracker_ref[0].with_column(col_name) + + try: + left_proxy = _make_tracking_proxy(left_table, on_left_access) + right_proxy = _make_tracking_proxy(right_table, on_right_access) + on(left_proxy, right_proxy) + + return JoinColumnExtractionResult.success( + left_tracker_ref[0].columns, + right_tracker_ref[0].columns, + ) + + except Exception as e: + return JoinColumnExtractionResult.failure(e) + + +# ============================================================================== +# Table Column Requirements +# ============================================================================== + + +@frozen +class TableColumnRequirements: + """Immutable representation of column requirements per table. + + Maps table names to sets of required column names. + """ + + requirements: FrozenDict[str, frozenset[str]] = field( + factory=lambda: FrozenDict({}), + converter=lambda d: FrozenDict( + {k: frozenset(v) if not isinstance(v, frozenset) else v for k, v in d.items()}, + ), + ) + + def with_column(self, table_name: str, col_name: str) -> TableColumnRequirements: + """Return new requirements with additional column for table.""" + current_cols = self.requirements.get(table_name, frozenset()) + updated_cols = current_cols | {col_name} + + return TableColumnRequirements( + requirements=dict(self.requirements) | {table_name: updated_cols}, + ) + + def with_columns( + self, + table_name: str, + col_names: Iterable[str], + ) -> TableColumnRequirements: + """Return new requirements with multiple columns for table.""" + current_cols = self.requirements.get(table_name, frozenset()) + updated_cols = current_cols | frozenset(col_names) + + return TableColumnRequirements( + requirements=dict(self.requirements) | {table_name: updated_cols}, + ) + + def merge(self, other: TableColumnRequirements) -> TableColumnRequirements: + """Merge requirements from another instance.""" + merged_dict = dict(self.requirements) + + for table, cols in other.requirements.items(): + if table in merged_dict: + merged_dict[table] = merged_dict[table] | cols + else: + merged_dict[table] = cols + + return TableColumnRequirements(requirements=merged_dict) + + def to_dict(self) -> dict[str, set[str]]: + """Convert to mutable dict for API compatibility.""" + return {table: set(cols) for table, cols in self.requirements.items()} + + +def _parse_prefixed_field(field_name: str) -> tuple[str | None, str]: + """Parse potentially prefixed field name. + + Args: + field_name: Field name, possibly prefixed (e.g., "table.column") + + Returns: + Tuple of (table_name or None, column_name) + """ + if "." in field_name: + table, col = field_name.split(".", 1) + return (table, col) + return (None, field_name) + + +def _extract_requirements_from_keys( + keys: Iterable[str], + merged_dimensions: Mapping[str, Any], + all_roots: Sequence[Any], + table: ir.Table, +) -> TableColumnRequirements: + """Extract column requirements from group-by keys using graph traversal.""" + requirements = TableColumnRequirements() + + for key in keys: + table_name, col_name = _parse_prefixed_field(key) + + if table_name: + # Prefixed: we know the table + requirements = requirements.with_column(table_name, col_name) + else: + # Unprefixed: resolve dimension or use conservative fallback + if key in merged_dimensions: + dim_fn = merged_dimensions[key] + + try: + # Evaluate the dimension to get an Ibis expression + dim_expr = dim_fn(table) + + # Walk the expression graph to find all Field nodes (column references) + field_names = {node.name for node in walk_nodes(ibis_ops.Field, dim_expr)} + + # Filter to only actual columns in the table schema + actual_cols = {col for col in field_names if col in table.columns} + + if actual_cols: + for root in all_roots: + if root.name: + requirements = requirements.with_columns(root.name, actual_cols) + else: + # Fallback: assume key name is column name + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, key) + except Exception: + logger.debug( + "dimension graph traversal failed for %r; " + "treating key name as column name", + key, + exc_info=True, + ) + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, key) + else: + # Raw column + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, key) + + return requirements + + +def _extract_requirements_from_measures( + aggs: Mapping[str, Callable], + all_roots: Sequence[Any], + table: ir.Table, +) -> TableColumnRequirements: + """Extract column requirements from measure aggregations using graph traversal.""" + from ._core import _unwrap + + requirements = TableColumnRequirements() + + for measure_name, measure_fn in aggs.items(): + fn = _unwrap(measure_fn) + + try: + # Evaluate the measure to get an Ibis expression + measure_expr = fn(table) + + # Walk the expression graph to find all Field nodes (column references) + field_names = {node.name for node in walk_nodes(ibis_ops.Field, measure_expr)} + + # Filter to only actual columns in the table schema + actual_cols = {col for col in field_names if col in table.columns} + + if actual_cols: + for root in all_roots: + if root.name: + requirements = requirements.with_columns(root.name, actual_cols) + except Exception: + logger.debug( + "measure graph traversal failed for %r; " + "falling back to name-based column inference", + measure_name, + exc_info=True, + ) + # Conservative fallback: if measure name looks like a column, include it + if measure_name.isidentifier(): + for root in all_roots: + if root.name: + requirements = requirements.with_column(root.name, measure_name) + + return requirements diff --git a/src/boring_semantic_layer/ops.py b/src/boring_semantic_layer/ops/_core.py similarity index 55% rename from src/boring_semantic_layer/ops.py rename to src/boring_semantic_layer/ops/_core.py index 695a2a4..7a4a011 100644 --- a/src/boring_semantic_layer/ops.py +++ b/src/boring_semantic_layer/ops/_core.py @@ -16,7 +16,7 @@ from ibis.expr.operations.relations import Field, Relation from ibis.expr.schema import Schema -from ._xorq import ( +from .._xorq import ( FrozenDict, FrozenOrderedDict, Schema as XorqSchema, @@ -44,10 +44,10 @@ def _reductions_for_expr(expr): from returns.result import Success, safe from toolz import curry -from . import projection_utils -from .compile_all import compile_grouped_with_all -from .graph_utils import walk_nodes -from .measure_scope import ( +from .. import projection_utils +from ..compile_all import compile_grouped_with_all +from ..graph_utils import walk_nodes +from ..measure_scope import ( AggregationExpr, AllOf, BinOp, @@ -56,7 +56,7 @@ def _reductions_for_expr(expr): MeasureScope, MethodCall, ) -from .nested_access import NestedAccessMarker +from ..nested_access import NestedAccessMarker logger = logging.getLogger(__name__) @@ -91,107 +91,11 @@ def __getattr__(self, name): return getattr(self._table, mapped) -def _is_deferred(expr) -> bool: - # Duck-type check: works for both ibis and xorq Deferred objects - return hasattr(expr, "_resolver") and hasattr(expr, "resolve") - - -def _normalize_to_name(arg: str | Deferred) -> str: - """Convert a string or simple ``_.name`` Deferred to a plain string name. - - Accepts a plain string (returned as-is) or a Deferred whose resolver is a - simple attribute access on the top-level ``_`` variable (e.g. ``_.origin``). - - Complex expressions like ``_.distance.sum()`` or ``_.a.b`` are rejected - with a ``TypeError``. - """ - if isinstance(arg, str): - return arg - - # Duck-type: works for both ibis and xorq Deferred objects - resolver = getattr(arg, "_resolver", None) - if resolver is None: - raise TypeError( - f"Expected a string name or Deferred expression (_.name), got {type(arg).__name__}" - ) - - obj = getattr(resolver, "obj", None) - - # Try attribute access first (_.name -> Attr resolver with .name) - name_wrapper = getattr(resolver, "name", None) - - # Fall back to getitem access (_["name"] -> Item resolver with .indexer) - if name_wrapper is None: - name_wrapper = getattr(resolver, "indexer", None) - - if name_wrapper is None or obj is None: - raise TypeError( - f"Only simple Deferred expressions like _.name or _['name'] are supported " - f"as positional arguments, got: {arg!r}" - ) - - # Reject chained access like _.a.b (obj would itself have an .obj attr) - if getattr(obj, "obj", None) is not None: - raise TypeError( - f"Only simple Deferred expressions like _.name or _['name'] are supported " - f"as positional arguments, got: {arg!r}" - ) - - # Attr.name / Item.indexer is a Just wrapper; unwrap via .value - raw_name = getattr(name_wrapper, "value", name_wrapper) - if not isinstance(raw_name, str): - raise TypeError(f"Could not extract string name from Deferred expression: {arg!r}") - - return raw_name - - -def _normalize_join_predicate(on): - """Normalize a join predicate to a two-argument callable. - - Accepts: - - ``str`` – equi-join on a column present in both sides - - ``Deferred`` (``_.col``) – same, after extracting the name - - ``list[str | Deferred]`` – compound equi-join on multiple columns - - ``callable`` (non-Deferred) – returned as-is (existing lambda API) - - ``None`` – returned as-is (for cross joins) - """ - if on is None: - return on - - if isinstance(on, str): - name = on - return lambda left, right: getattr(left, name) == getattr(right, name) - - if _is_deferred(on): - name = _normalize_to_name(on) - return lambda left, right: getattr(left, name) == getattr(right, name) - - if isinstance(on, (list, tuple)): - names = [_normalize_to_name(item) for item in on] - if len(names) == 1: - name = names[0] - return lambda left, right: getattr(left, name) == getattr(right, name) - - def _compound_predicate(left, right): - from functools import reduce - from operator import and_ - - preds = [getattr(left, n) == getattr(right, n) for n in names] - return reduce(and_, preds) - - return _compound_predicate - - if callable(on): - return on - - raise TypeError( - f"join `on` must be a string, Deferred (_.col), list of strings/Deferred, " - f"or a callable, got {type(on).__name__}" - ) +from ._normalize import _normalize_join_predicate, _normalize_to_name # noqa: E402, F401 if TYPE_CHECKING: - from .expr import ( + from ..expr import ( SemanticFilter, SemanticGroupBy, SemanticLimit, @@ -200,117 +104,12 @@ def _compound_predicate(left, right): ) -def _patch_xorq_sortkey_compat(): - """Register a map_ibis handler so ibis SortKey → xorq SortKey. - - ibis 11 uses ``SortKey.expr``, ibis 12 renamed it to ``SortKey.arg``, - while xorq's vendored ibis keeps ``SortKey.expr``. Handle both. - """ - from ibis.expr.operations.sortkeys import SortKey as IbisSortKey - - from ._xorq import SortKey as XorqSortKey, map_ibis - - if IbisSortKey in map_ibis.registry: - return # already patched - - @map_ibis.register(IbisSortKey) - def _map_sort_key(val, kwargs=None): - # ibis 12 uses .arg, ibis 11 uses .expr - sort_expr = getattr(val, "arg", None) or getattr(val, "expr") - return XorqSortKey( - expr=map_ibis(sort_expr, None), - ascending=val.ascending, - nulls_first=val.nulls_first, - ) - - -def _ensure_xorq_table(table): - """Convert plain ibis Table to xorq-vendored ibis if possible. - - This is the single boundary between user-supplied ibis tables and - BSL's internal xorq representation. ``SemanticModel`` calls it once - at construction so internal code paths can assume xorq tables when - the backend is supported, and a plain ibis fallback otherwise. - - Falls back to returning the plain ibis table when the backend is not - supported by xorq (e.g. Databricks). Idempotent: calling it on a - xorq-vendored table is a cheap no-op. - """ - _patch_xorq_sortkey_compat() - if "xorq.vendor.ibis" not in type(table).__module__: - try: - from ._xorq import from_ibis - - return from_ibis(table) - except Exception: - # Backend isn't supported by xorq's map_ibis registry (e.g. - # Databricks). Fall back so plain-ibis paths can still execute. - logger.debug( - "from_ibis failed for %s; using plain ibis table", - type(table).__module__, - exc_info=True, - ) - return table - return table - - -def _rebind_to_backend(expr, target_backend): - """Rebind every ``DatabaseTable`` op in *expr* to *target_backend*. - - Low-level primitive shared with ``serialization.reconstruct``. - No-op on plain ibis expressions or when xorq is unavailable for any - reason; callers must pass a xorq-vendored ``target_backend``. - """ - try: - from ._xorq import relations as xorq_rel - except ImportError: - return expr - - def _recreate(op, _kwargs, **overrides): - kwargs = dict(zip(op.__argnames__, op.__args__, strict=False)) - if _kwargs: - kwargs.update(_kwargs) - kwargs.update(overrides) - return op.__recreate__(kwargs) - - def replacer(op, _kwargs): - if isinstance(op, xorq_rel.DatabaseTable) and op.source is not target_backend: - return _recreate(op, _kwargs, source=target_backend) - if _kwargs: - return _recreate(op, _kwargs) - return op - - return expr.op().replace(replacer).to_expr() - - -def _rebind_to_canonical_backend(expr): - """Rebind divergent ``DatabaseTable`` backends in *expr* to share one. - - ``from_ibis()`` creates a distinct ``Backend`` per call, so expressions - built by composing separately-converted tables contain multiple - backends. Picking the first ``DatabaseTable``'s source as canonical - and rebinding the rest eliminates "Multiple backends found" errors - at execution time. - - No-op on plain ibis expressions (not xorq-vendored). - """ - try: - from ._xorq import relations as xorq_rel, walk_nodes - except ImportError: - return expr - - try: - db_tables = list(walk_nodes((xorq_rel.DatabaseTable,), expr)) - except Exception: - # walk_nodes can't traverse plain ibis trees; treat as no-op. - logger.debug("walk_nodes failed on plain ibis expr", exc_info=True) - return expr - - canonical = db_tables[0].source if db_tables else None - if canonical is None: - return expr - - return _rebind_to_backend(expr, canonical) +from ._xorq_compat import ( # noqa: E402 + _ensure_xorq_table, + _patch_xorq_sortkey_compat, + _rebind_to_backend, + _rebind_to_canonical_backend, +) def _to_untagged(source: Any) -> ir.Table: @@ -318,7 +117,7 @@ def _to_untagged(source: Any) -> ir.Table: def _semantic_table(*args, **kwargs) -> SemanticTable: - from .expr import SemanticModel + from ..expr import SemanticModel return SemanticModel(*args, **kwargs) @@ -327,138 +126,7 @@ def _unwrap(wrapped: Any) -> Any: return wrapped.unwrap if isinstance(wrapped, _CallableWrapper) else wrapped -def _collect_chain(op: Relation) -> list[Relation]: - """Walk op.source (or op.left for joins) back to root, return list from root to current.""" - chain = [op] - current = op - while True: - if hasattr(current, "source") and current.source is not None: - chain.append(current.source) - current = current.source - elif hasattr(current, "left") and current.left is not None: - chain.append(current.left) - current = current.left - else: - break - chain.reverse() - return chain - - -def _format_op_summary(op: Relation) -> str: - """Return a one-line summary string for a non-root semantic op.""" - # Import here to avoid circular imports at module level - cls = type(op).__name__ - - if isinstance(op, SemanticFilterOp): - predicate = object.__getattribute__(op, "predicate") - pred_name = "" - if hasattr(predicate, "__name__"): - pred_name = predicate.__name__ - elif hasattr(predicate, "unwrap"): - unwrapped = predicate.unwrap - if hasattr(unwrapped, "__name__"): - pred_name = unwrapped.__name__ - return f"Filter(\u03bb {pred_name})" - - if isinstance(op, SemanticMutateOp): - post = object.__getattribute__(op, "post") - cols = list(post.keys()) - return f"Mutate({', '.join(cols)})" - - if isinstance(op, SemanticGroupByOp): - keys = object.__getattribute__(op, "keys") - return f"GroupBy({', '.join(keys)})" - - if isinstance(op, SemanticAggregateOp): - aggs = object.__getattribute__(op, "aggs") - agg_names = list(aggs.keys()) - return f"Aggregate({', '.join(agg_names)})" - - if isinstance(op, SemanticOrderByOp): - keys = object.__getattribute__(op, "keys") - key_strs = [k if isinstance(k, str) else repr(k) for k in keys] - return f"OrderBy({', '.join(key_strs)})" - - if isinstance(op, SemanticLimitOp): - return f"Limit({op.n})" - - if isinstance(op, SemanticProjectOp): - fields = object.__getattribute__(op, "fields") - return f"Project({', '.join(fields)})" - - if isinstance(op, SemanticUnnestOp): - column = object.__getattribute__(op, "column") - return f"Unnest({column})" - - if isinstance(op, SemanticJoinOp): - how = object.__getattribute__(op, "how") - right = object.__getattribute__(op, "right") - right_name = "" - if isinstance(right, SemanticTableOp): - right_name = object.__getattribute__(right, "name") or "" - if not right_name: - # Try to find a root name from right side - roots = _find_all_root_models(right) - if roots: - right_name = object.__getattribute__(roots[0], "name") or "" - if right_name: - return f"Join({how}, right={right_name})" - return f"Join({how})" - - if isinstance(op, SemanticIndexOp): - parts = [] - selector = object.__getattribute__(op, "selector") - by = object.__getattribute__(op, "by") - sample = object.__getattribute__(op, "sample") - if selector is not None: - if isinstance(selector, tuple): - parts.append(", ".join(selector)) - else: - parts.append(str(selector)) - if by is not None: - parts.append(f"by={by}") - if sample is not None: - parts.append(f"sample={sample}") - return f"Index({', '.join(parts)})" - - # Fallback for unknown op types - return cls.replace("Semantic", "").replace("Op", "") - - -def _format_root(root_op: SemanticTableOp) -> str: - """Format a SemanticTableOp root using the fmt registry from format.py.""" - from boring_semantic_layer.format import fmt - - try: - return fmt(root_op) - except Exception: - # Fallback if format module isn't available - name = object.__getattribute__(root_op, "name") - return f"SemanticTable: {name}" if name else "SemanticTable" - - -def _semantic_repr(op: Relation) -> str: - chain = _collect_chain(op) - - # Find the root (first element should be a SemanticTableOp) - root = chain[0] - if isinstance(root, SemanticTableOp): - lines = [_format_root(root)] - else: - # Fallback: no SemanticTableOp root found - from ibis.expr.format import pretty - - try: - return pretty(op) - except Exception: - return object.__repr__(op) - - # Append pipeline steps - for step in chain[1:]: - if not isinstance(step, SemanticTableOp): - lines.append(f"-> {_format_op_summary(step)}") - - return "\n".join(lines) +from ._format import _collect_chain, _format_op_summary, _format_root, _semantic_repr # noqa: E402, F401 def _make_schema(fields_dict: dict[str, str]): @@ -482,7 +150,7 @@ def _resolve_expr(expr: Deferred | Callable | Any, scope: ir.Table) -> ir.Value: scope_is_xorq = "xorq.vendor.ibis" in scope_module if result_is_regular_ibis and scope_is_xorq: - from ._xorq import from_ibis + from .._xorq import from_ibis result = from_ibis(result) @@ -606,3795 +274,1759 @@ def _classify_dependencies( } -@frozen -class _CallableWrapper: - """Hashable wrapper for Callable and Deferred. - - Both raw callables (lambda) and user Deferred (_.foo) are not hashable - and cannot be stored in FrozenDict. This wrapper provides hashability - using identity-based hashing. - """ - - _fn: Any - - def __call__(self, *args, **kwargs): - return self._fn(*args, **kwargs) - - def __hash__(self): - # should this be dask.base.tokenize()? - return hash(id(self._fn)) - - @property - def unwrap(self): - return self._fn - - -def _ensure_wrapped(fn: Any) -> _CallableWrapper: - """Wrap Callable or Deferred for hashability.""" - return fn if isinstance(fn, _CallableWrapper) else _CallableWrapper(fn) - - -def _infer_unnest(fn: Callable, table: Any) -> tuple[str, ...]: - """Infer required unnest operations from the table. +from ._callable import _CallableWrapper, _ensure_wrapped, _infer_unnest # noqa: E402 - Examples: - to_semantic_table(tbl).with_measures(...) -> () # Session level - to_semantic_table(tbl).unnest("hits").with_measures(...) -> ("hits",) - unnested.unnest("product").with_measures(...) -> ("product",) - """ - from .expr import SemanticUnnest - - if isinstance(table, SemanticUnnest): - op = table.op() - # SemanticUnnestOp always has column attribute - return (op.column,) - - return () - - -def _extract_measure_metadata( - fn_or_expr: Any, -) -> tuple[Any, str | None, tuple, Mapping[str, Any]]: - """Extract metadata from various measure representations.""" - if isinstance(fn_or_expr, dict): - return ( - fn_or_expr["expr"], - fn_or_expr.get("description"), - tuple(fn_or_expr.get("requires_unnest", [])), - dict(fn_or_expr.get("metadata") or {}), - ) - elif isinstance(fn_or_expr, Measure): - return ( - fn_or_expr.expr, - fn_or_expr.description, - fn_or_expr.requires_unnest, - dict(fn_or_expr.metadata), - ) - else: - return (fn_or_expr, None, (), {}) - - -_AGG_METHODS = frozenset({"sum", "mean", "avg", "count", "min", "max"}) - - -def _is_calculated_measure(val: Any) -> bool: - # A MethodCall with an aggregation method on a MeasureRef is a base measure: - # the column name matched a known measure name in MeasureScope, but the user - # is really defining a column aggregation (e.g. lambda t: t.flight_count.sum()). - if ( - isinstance(val, MethodCall) - and val.method in _AGG_METHODS - and isinstance(val.receiver, MeasureRef) - ): - return False - return isinstance(val, MeasureRef | AllOf | BinOp | MethodCall | int | float) - - -def _matches_aggregation_pattern(measure_expr, agg_expr, tbl): - if not isinstance(agg_expr, AggregationExpr): - return Success(False) - - @curry - def evaluate_in_scope(tbl, expr): - """Evaluate measure expression in a ColumnScope.""" - scope = ColumnScope(_tbl=tbl) - return ( - expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr - ) - - @curry - def has_matching_operation(agg_expr, result): - """Check if the operation matches the expected aggregation. - - All our supported aggregations (Sum, Mean, Count, Min, Max) are ibis operations. - """ - op_name = type(result.op()).__name__.lower() - expected_op = "avg" if agg_expr.operation.lower() == "mean" else agg_expr.operation.lower() - - return expected_op in op_name - - @curry - def has_matching_column(agg_expr, result): - """Check if result's operation references the expected column. - - All supported aggregation operations (Sum, Mean, Count, Min, Max) have: - - args[0]: Field operation with .name attribute - - args[1]: Optional where clause (typically None) - """ - op = result.op() - if not isinstance(op.args[0], Field): - return False +from ._measure_helpers import ( # noqa: E402 + _AGG_METHODS, + _build_json_definition, + _classify_measure, + _extract_measure_metadata, + _find_matching_measure, + _format_column_error, + _is_calculated_measure, + _make_base_measure, + _matches_aggregation_pattern, +) - return op.args[0].name == agg_expr.column - def matches_pattern(result): - """Check if result matches both operation and column.""" - return has_matching_operation(agg_expr, result) and has_matching_column(agg_expr, result) +from ._values import ( # noqa: E402 + Dimension, + Measure, + _DimPrefixProxy, + _DimensionTableProxy, + _is_deferred, +) - return safe(lambda: evaluate_in_scope(tbl, measure_expr))().map(matches_pattern) -def _find_matching_measure(agg_expr, known_measures: dict, tbl): - """Find a measure that matches the aggregation expression pattern. - Returns Maybe[str] using functional patterns. - """ - if not isinstance(agg_expr, AggregationExpr): - return Nothing +from ._basic_ops import ( # noqa: E402, F401 + SemanticFilterOp, + SemanticGroupByOp, + SemanticProjectOp, + SemanticTableOp, + _build_select_or_aggregate, + _classify_fields, + _evaluate_measures_with_unnesting, + _process_nested_access_marker, +) - @curry - def matches_pattern(agg_expr, tbl, measure_obj): - """Check if measure matches the aggregation pattern. - All measure_obj values are Measure instances with an expr attribute. - """ - result = _matches_aggregation_pattern(measure_obj.expr, agg_expr, tbl) - return result.value_or(False) +@frozen +class _MeasureSpec: + name: str + kind: str # 'agg' or 'calc' + value: Any - for measure_name, measure_obj in known_measures.items(): - if matches_pattern(agg_expr, tbl, measure_obj): - return Some(measure_name) - return Nothing +@frozen +class _AggregationPlan: + agg_specs: FrozenDict[str, Callable] + calc_specs: FrozenDict[str, Any] + requested_measures: tuple[str, ...] + group_by_cols: tuple[str, ...] -def _make_base_measure( +def _resolve_aggregation_exprs( expr: Any, - description: str | None, - requires_unnest: tuple, - metadata: Mapping[str, Any] | None = None, -) -> Measure: - """Create a base measure with proper callable wrapping using functional patterns.""" - + merged_base_measures: dict, + merged_calc_measures: dict, + tbl: ir.Table, +) -> Any: @curry - def apply_aggregation(operation: str, column): - """Apply aggregation operation to a column using functional dispatch.""" - operations = { - "sum": lambda c: c.sum(), - "mean": lambda c: c.mean(), - "avg": lambda c: c.mean(), - "count": lambda c: c.count(), - "min": lambda c: c.min(), - "max": lambda c: c.max(), - } - - return ( - Maybe.from_optional(operations.get(operation)) - .map(lambda fn: fn(column)) - .value_or( - (_ for _ in ()).throw(ValueError(f"Unknown aggregation operation: {operation}")) - ) - ) + def find_in_calc_measures(expr, calc_measures): + for calc_name, calc_expr in calc_measures.items(): + if isinstance(calc_expr, AggregationExpr) and ( + calc_expr.column == expr.column and calc_expr.operation == expr.operation + ): + return Some(calc_name) + return Nothing - @curry - def evaluate_expr(expr, scope): - """Evaluate expression in given scope.""" - return ( - expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr + def resolve_aggregation(agg_expr): + matched = _find_matching_measure(agg_expr, merged_base_measures, tbl) + return matched.map(MeasureRef).value_or( + find_in_calc_measures(agg_expr, merged_calc_measures).map(MeasureRef).value_or(agg_expr) ) - def convert_aggregation_expr(t, agg_expr: AggregationExpr): - """Convert AggregationExpr to ibis expression.""" - if agg_expr.operation == "count": - result = t.count() - else: - result = apply_aggregation(agg_expr.operation, t[agg_expr.column]) - - for method_name, args, kwargs_tuple in agg_expr.post_ops: - result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) - - return result - - raw_expr = expr._fn if isinstance(expr, _CallableWrapper) else expr - if isinstance(expr, AggregationExpr): - - def wrapped_expr(t): - """Convert AggregationExpr to ibis expression.""" - return convert_aggregation_expr(t, expr) - - return Measure( - expr=wrapped_expr, - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), + return resolve_aggregation(expr) + elif isinstance(expr, MethodCall): + return MethodCall( + receiver=_resolve_aggregation_exprs( + expr.receiver, merged_base_measures, merged_calc_measures, tbl + ), + method=expr.method, + args=expr.args, + kwargs=expr.kwargs, ) - - if callable(expr): - - def wrapped_expr(t): - """Wrapped expression that handles AggregationExpr conversion.""" - scope = ColumnScope(_tbl=t) - result = evaluate_expr(expr, scope) - - if isinstance(result, AggregationExpr): - return convert_aggregation_expr(t, result) - return result - - return Measure( - expr=wrapped_expr, - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), + elif isinstance(expr, BinOp): + return BinOp( + op=expr.op, + left=_resolve_aggregation_exprs( + expr.left, merged_base_measures, merged_calc_measures, tbl + ), + right=_resolve_aggregation_exprs( + expr.right, merged_base_measures, merged_calc_measures, tbl + ), ) + elif isinstance(expr, AllOf) and isinstance(expr.ref, AggregationExpr): + return AllOf(resolve_aggregation(expr.ref)) else: - return Measure( - expr=lambda t, fn=expr: evaluate_expr(fn, ColumnScope(_tbl=t)), - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) - - -def _classify_measure( - fn_or_expr: Any, scope: Any, measure_name: str | None = None -) -> tuple[str, Any]: - """Classify measure as 'calc' or 'base' with appropriate handling.""" - from .measure_scope import validate_calc_ast + return expr - expr, description, requires_unnest, metadata = _extract_measure_metadata(fn_or_expr) - resolved = safe(lambda: _resolve_expr(expr, scope))().map( - lambda val: ("calc", val) if _is_calculated_measure(val) else None - ) +def _create_measure_spec( + name: str, + fn_wrapped: Any, + scope: Any, + is_post_agg: bool, + merged_base_measures: dict, + merged_calc_measures: dict, + tbl: ir.Table, +) -> _MeasureSpec: + fn = _unwrap(fn_wrapped) + val = _resolve_expr(fn, scope) + val = _resolve_aggregation_exprs(val, merged_base_measures, merged_calc_measures, tbl) - if isinstance(resolved, Success) and resolved.unwrap() is not None: - kind, value = resolved.unwrap() - validate_calc_ast(value, measure_name) - return (kind, value) - - if not requires_unnest and callable(expr): - # All scopes (MeasureScope, ColumnScope) have tbl attribute - table = scope.tbl - inferred_unnest = _infer_unnest(expr, table) - requires_unnest = requires_unnest or inferred_unnest - - return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) - - -def _build_json_definition( - dims_dict: dict, - meas_dict: dict, - name: str | None = None, - description: str | None = None, -) -> dict: - result = { - "dimensions": {n: spec.to_json() for n, spec in dims_dict.items()}, - "measures": {n: spec.to_json() for n, spec in meas_dict.items()}, - "entity_dimensions": {n: spec.to_json() for n, spec in dims_dict.items() if spec.is_entity}, - "event_timestamp": { - n: spec.to_json() for n, spec in dims_dict.items() if spec.is_event_timestamp - }, - "time_dimensions": { - n: spec.to_json() for n, spec in dims_dict.items() if spec.is_time_dimension - }, - "name": name, - } - if description is not None: - result["description"] = description - return result + if is_post_agg: + return _MeasureSpec(name=name, kind="agg", value=fn) + if isinstance(val, MeasureRef): + ref_name = val.name + if ref_name in merged_calc_measures: + calc_expr = merged_calc_measures[ref_name] + resolved = _resolve_aggregation_exprs( + calc_expr, merged_base_measures, merged_calc_measures, tbl + ) + return _MeasureSpec(name=name, kind="calc", value=resolved) + elif ref_name in merged_base_measures: + return _MeasureSpec(name=name, kind="agg", value=merged_base_measures[ref_name]) + else: + return _MeasureSpec(name=name, kind="calc", value=val) -def _format_column_error(e: AttributeError, table: ir.Table) -> str: - """Format a helpful error message for missing column errors.""" - # Extract the column name from the error - match = re.search(r"has no attribute ['\"]([^'\"]+)['\"]", str(e)) - missing_col = match.group(1) if match else "unknown" + if isinstance(val, AllOf | BinOp | MethodCall | int | float): + return _MeasureSpec(name=name, kind="calc", value=val) - # Get available columns - available_cols = list(table.columns) if hasattr(table, "columns") else [] + return _MeasureSpec(name=name, kind="agg", value=fn) - # Build error message - parts = [f"Dimension expression references non-existent column '{missing_col}'."] - if len(available_cols) > 20: - parts.append(f"Table has {len(available_cols)} columns. First 15: {available_cols[:15]}") - elif available_cols: - parts.append(f"Available columns: {available_cols}") +def _make_agg_callable(measure: Any) -> Callable: + if _is_deferred(measure): + return lambda t: measure.resolve(ColumnScope(_tbl=t)) + elif callable(measure): + return lambda t: measure(ColumnScope(_tbl=t)) else: - parts.append(f"No columns available in {type(table).__name__} object") - - # Suggest similar column names - suggestions = get_close_matches(missing_col, available_cols, n=3, cutoff=0.6) - if suggestions: - parts[-1] += f". Did you mean: {suggestions}?" - - # Add helpful tip - example = suggestions[0] if suggestions else "column_name" - parts.append( - f"\n\nTip: Check that your dimension expression uses the correct column name. " - f"For example: lambda t: t.{example}" - ) - - return " ".join(parts) - - -class _DimPrefixProxy: - """Resolves ``proxy.column`` to ``dims["prefix.column"](table)``.""" - - __slots__ = ("_tbl", "_dims", "_prefix") + return lambda t: measure(t) - def __init__(self, tbl, dims: dict, prefix: str): - object.__setattr__(self, "_tbl", tbl) - object.__setattr__(self, "_dims", dims) - object.__setattr__(self, "_prefix", prefix) - def __getattr__(self, name: str): - full_name = f"{self._prefix}.{name}" - if full_name in self._dims: - return self._dims[full_name](self._tbl) - raise AttributeError( - f"No dimension '{full_name}' found. " - f"Available dimensions with prefix '{self._prefix}.': " - f"{[k for k in self._dims if k.startswith(self._prefix + '.')]}" - ) +def _collect_all_measure_refs(calc_exprs) -> frozenset[str]: + all_refs = set() + for expr in calc_exprs: + _collect_measure_refs(expr, all_refs) + return frozenset(all_refs) -class _DimensionTableProxy: - """Proxy that wraps an ibis table to support model-prefix navigation. +def _expand_calc_measure_refs( + expr: Any, + merged_base_measures: dict, + merged_calc_measures: dict, + tbl: ir.Table, + cache: dict[str, Any] | None = None, + path: tuple[str, ...] = (), +) -> Any: + """Inline calc-measure references transitively for multi-layer formulas.""" + cache = {} if cache is None else cache - Allows dimension lambdas like ``lambda t: t.flights.carrier`` to work on - joined tables by resolving ``t.flights.carrier`` through the merged - dimension map (``dims["flights.carrier"](table)``). - """ - - __slots__ = ("_tbl", "_dims") - - def __init__(self, tbl, dims: dict): - object.__setattr__(self, "_tbl", tbl) - object.__setattr__(self, "_dims", dims) + def _lift_to_allof(value: Any) -> Any: + """Lift an expanded expression into totals-space via AllOf on refs.""" + if isinstance(value, MeasureRef): + return AllOf(value) + if isinstance(value, BinOp): + return BinOp( + op=value.op, + left=_lift_to_allof(value.left), + right=_lift_to_allof(value.right), + ) + if isinstance(value, MethodCall): + return MethodCall( + receiver=_lift_to_allof(value.receiver), + method=value.method, + args=value.args, + kwargs=value.kwargs, + ) + return value - def __getattr__(self, name: str): - prefix = f"{name}." - if any(k.startswith(prefix) for k in self._dims): - return _DimPrefixProxy(self._tbl, self._dims, name) - return getattr(self._tbl, name) + if isinstance(expr, MeasureRef): + ref_name = expr.name + if ref_name not in merged_calc_measures: + return expr + if ref_name in cache: + return cache[ref_name] + if ref_name in path: + cycle = " -> ".join((*path, ref_name)) + raise ValueError(f"Circular calculated measure dependency detected: {cycle}") - def __getitem__(self, name: str): - if name in self._dims: - return self._dims[name](self._tbl) - return self._tbl[name] + resolved = _resolve_aggregation_exprs( + merged_calc_measures[ref_name], merged_base_measures, merged_calc_measures, tbl + ) + expanded = _expand_calc_measure_refs( + resolved, + merged_base_measures, + merged_calc_measures, + tbl, + cache, + (*path, ref_name), + ) + cache[ref_name] = expanded + return expanded - @property - def columns(self): - return self._tbl.columns - - -@frozen(kw_only=True, slots=True) -class Dimension: - expr: Callable[[ir.Table], ir.Value] | Deferred - description: str | None = None - is_entity: bool = False - is_time_dimension: bool = False - is_event_timestamp: bool = False - smallest_time_grain: str | None = None - derived_dimensions: tuple[str, ...] = () - metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) - - def __call__(self, table: ir.Table, _dims: dict | None = None) -> ir.Value: - try: - return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) - except AttributeError as e: - # Retry with a prefix-aware proxy for joined tables where - # model prefixes are used (e.g., lambda t: t.flights.carrier) - if _dims and not _is_deferred(self.expr) and callable(self.expr): - try: - proxy = _DimensionTableProxy(table, _dims) - return self.expr(proxy) - except AttributeError as proxy_err: - # Preserve explicit prefix-proxy errors (e.g. missing - # "model.field") to avoid silent fallback to unprefixed - # columns, but keep normal missing-column errors on the - # original table so they get the helpful formatter below. - if str(proxy_err).startswith("No dimension '"): - raise - except Exception: - pass - # Provide helpful error for missing columns - if "'Table' object has no attribute" in str( - e - ) or "'Join' object has no attribute" in str(e): - raise AttributeError(_format_column_error(e, table)) from e - raise - - def to_json(self) -> Mapping[str, Any]: - base = {"description": self.description} - if self.is_entity: - base["is_entity"] = True - if self.is_event_timestamp: - base["is_event_timestamp"] = True - if self.is_time_dimension: - base["smallest_time_grain"] = self.smallest_time_grain - if self.derived_dimensions: - base["derived_dimensions"] = list(self.derived_dimensions) - if self.metadata: - base.update(self.metadata) - return base - - def __hash__(self) -> int: - return hash( - ( - self.description, - self.is_entity, - self.is_event_timestamp, - self.is_time_dimension, - self.smallest_time_grain, - self.derived_dimensions, + if isinstance(expr, MethodCall): + return MethodCall( + receiver=_expand_calc_measure_refs( + expr.receiver, merged_base_measures, merged_calc_measures, tbl, cache, path ), + method=expr.method, + args=expr.args, + kwargs=expr.kwargs, ) - -@frozen(kw_only=True, slots=True) -class Measure: - expr: Callable[[ir.Table], ir.Value] | Deferred - description: str | None = None - requires_unnest: tuple[str, ...] = () # Internal: Arrays that must be unnested - original_expr: Any = field(default=None, eq=False, hash=False) - metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) - - def __call__(self, table: ir.Table) -> ir.Value: - return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) - - @property - def locality(self) -> str | None: - """Derive locality from requires_unnest (most nested level).""" - return self.requires_unnest[-1] if self.requires_unnest else None - - def to_json(self) -> Mapping[str, Any]: - base = {"description": self.description} - if self.locality: - base["locality"] = self.locality - if self.requires_unnest: - base["requires_unnest"] = list(self.requires_unnest) - if self.metadata: - base.update(self.metadata) - return base - - def __hash__(self) -> int: - return hash((self.description, self.requires_unnest)) - - -class SemanticTableOp(Relation): - """Relation with semantic metadata (dimensions and measures). - - Stores ir.Table expression directly to avoid .op() → .to_expr() conversions. - - Note: Accepts both regular ibis.Table and xorq's vendored ibis.Table. - Regular ibis tables are automatically converted to xorq in __init__. - """ - - table: Any # Accepts both ir.Table and regular ibis.expr.types.Table - dimensions: FrozenDict[str, Dimension] - measures: FrozenDict[str, Measure] - calc_measures: FrozenDict[str, Any] - name: str | None = None - description: str | None = None - _source_join: Any = field( - default=None, repr=False - ) # Track if this wraps a join (SemanticJoinOp) for optimization - - def __init__( - self, - table: ir.Table, - dimensions: dict[str, Dimension] | FrozenDict[str, Dimension], - measures: dict[str, Measure] | FrozenDict[str, Measure], - calc_measures: dict[str, Any] | FrozenDict[str, Any], - name: str | None = None, - description: str | None = None, - _source_join: Any = None, - ) -> None: - # Accept both regular ibis and xorq tables without conversion - # This allows using regular ibis by default, xorq only when provided - super().__init__( - table=table, - dimensions=FrozenDict(dimensions) - if not isinstance(dimensions, FrozenDict) - else dimensions, - measures=FrozenDict(measures) if not isinstance(measures, FrozenDict) else measures, - calc_measures=FrozenDict(calc_measures) - if not isinstance(calc_measures, FrozenDict) - else calc_measures, - name=name, - description=description, - _source_join=_source_join, + if isinstance(expr, BinOp): + return BinOp( + op=expr.op, + left=_expand_calc_measure_refs( + expr.left, merged_base_measures, merged_calc_measures, tbl, cache, path + ), + right=_expand_calc_measure_refs( + expr.right, merged_base_measures, merged_calc_measures, tbl, cache, path + ), ) - def __repr__(self) -> str: - return _semantic_repr(self) + if isinstance(expr, AllOf): + if isinstance(expr.ref, MeasureRef): + expanded_ref = _expand_calc_measure_refs( + expr.ref, merged_base_measures, merged_calc_measures, tbl, cache, path + ) + if isinstance(expanded_ref, MeasureRef): + return AllOf(expanded_ref) + return _lift_to_allof(expanded_ref) + return expr - @property - def values(self) -> FrozenOrderedDict[str, Any]: - dims = self.get_dimensions() - measures = self.get_measures() - calc_measures = self.get_calculated_measures() - # Build enriched table with all dimensions resolved (handles derived deps) - enriched = _mutate_dimensions_with_dependencies(self.table, dims.keys(), dims) - base_values = { - **{col: self.table[col].op() for col in self.table.columns}, - **{name: enriched[name].op() for name in dims}, - **{name: fn(enriched).op() for name, fn in measures.items()}, - } - # Resolve calculated measure types via a dummy table with base measure dtypes. - # ``infer_calc_dtype`` mirrors the AggregationExpr rewrite from - # ``compile_grouped_with_all`` so calc measures with inline aggregations - # (e.g. ``AllOf(AggregationExpr)``) round-trip through type inference. - if calc_measures: - from .compile_all import _get_ibis_module, infer_calc_dtype - - measure_schema = { - name: base_values[name].dtype for name in measures if name in base_values - } - ibis_module = _get_ibis_module(enriched) - for name, expr in calc_measures.items(): - try: - compiled = infer_calc_dtype( - expr, measure_schema, enriched, ibis_module - ) - base_values[name] = compiled.op() - except Exception as e: - # Joined models with dotted column names, calc measures - # whose inline aggregations don't apply to the dummy schema, - # etc. Type info is best-effort; surface for debugging. - logger.debug( - "calc-measure type inference failed for %r: %s", name, e - ) - return FrozenOrderedDict(base_values) + return expr - @property - def schema(self): - fields_dict = {name: str(v.dtype) for name, v in self.values.items()} - return _make_schema(fields_dict) - @property - def json_definition(self) -> Mapping[str, Any]: - return _build_json_definition( - self.get_dimensions(), - self.get_measures(), - self.name, - self.description, +def _build_aggregation_plan( + aggs: dict, + keys: tuple, + scope: Any, + is_post_agg: bool, + merged_base_measures: dict, + merged_calc_measures: dict, + tbl: ir.Table, +) -> _AggregationPlan: + specs = [ + _create_measure_spec( + name, fn, scope, is_post_agg, merged_base_measures, merged_calc_measures, tbl ) + for name, fn in aggs.items() + ] - @property - def _dims(self) -> dict[str, Dimension]: - return dict(self.get_dimensions()) - - @property - def _base_measures(self) -> dict[str, Measure]: - return dict(self.get_measures()) - - @property - def _calc_measures(self) -> dict[str, Any]: - return dict(self.get_calculated_measures()) + agg_specs_list = [s for s in specs if s.kind == "agg"] + calc_specs_list = [s for s in specs if s.kind == "calc"] - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of base measures with metadata.""" - return object.__getattribute__(self, "measures") + agg_specs = FrozenDict({s.name: _make_agg_callable(s.value) for s in agg_specs_list}) + calc_specs = FrozenDict({s.name: s.value for s in calc_specs_list}) - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions with metadata.""" - return object.__getattribute__(self, "dimensions") + calc_cache: dict[str, Any] = {} + expanded_calc_specs = FrozenDict( + { + name: _expand_calc_measure_refs( + expr, + merged_base_measures, + merged_calc_measures, + tbl, + cache=calc_cache, + path=(name,), + ) + for name, expr in calc_specs.items() + } + ) - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures with metadata.""" - return self.calc_measures + referenced = _collect_all_measure_refs(expanded_calc_specs.values()) + additional_aggs = { + ref: _make_agg_callable(merged_base_measures[ref]) + for ref in referenced + if ref not in agg_specs and ref in merged_base_measures + } - def get_graph(self) -> dict[str, dict[str, Any]]: - from .graph_utils import build_dependency_graph + final_agg_specs = FrozenDict({**agg_specs, **additional_aggs}) - return build_dependency_graph( - self.get_dimensions(), - self.get_measures(), - self.get_calculated_measures(), - self.table, - ) + return _AggregationPlan( + agg_specs=final_agg_specs, + calc_specs=expanded_calc_specs, + requested_measures=tuple(aggs.keys()), + group_by_cols=tuple(keys), + ) - def __getattribute__(self, name: str): - """Override attribute access to return tuples for dimensions/measures. - This provides a cleaner API where .dimensions returns ('dim1', 'dim2') - instead of the full FrozenDict. Use get_dimensions() to get the full dict. - """ - # For special/internal attributes (dunder methods), use default behavior - # This is critical for xorq's vendored ibis which uses __precomputed_hash__, etc. - if name.startswith("__") and name.endswith("__"): - return object.__getattribute__(self, name) - - # Custom behavior for dimensions and measures - if name == "dimensions": - dims = object.__getattribute__(self, "dimensions") - return tuple(dims.keys()) - if name == "measures": - base_meas = object.__getattribute__(self, "measures") - calc_meas = object.__getattribute__(self, "calc_measures") - return tuple(base_meas.keys()) + tuple(calc_meas.keys()) - - # Default behavior for everything else - return object.__getattribute__(self, name) +# --------------------------------------------------------------------------- +# Pre-aggregation helpers (fan-out / chasm trap prevention) +# --------------------------------------------------------------------------- - def to_untagged(self): - # Conversion happens at SemanticModel construction; self.table is - # already xorq when supported, plain ibis when not. - return self.table +@frozen +class _JoinTreeInfo: + """Information collected from the join tree for pre-aggregation decisions.""" -class SemanticFilterOp(Relation): - source: Relation - predicate: Callable + has_join_many: bool + table_cardinalities: dict # table_name → "one"|"many"|"root" + table_join_keys: dict # table_name → {raw_col_names} + table_ops: dict # table_name → SemanticTableOp - def __init__(self, source: Relation, predicate: Callable) -> None: - super().__init__( - source=Relation.__coerce__(source), - predicate=_ensure_wrapped(predicate), - ) - def __repr__(self) -> str: - return _semantic_repr(self) +def _collect_join_tree_info(join_op: SemanticJoinOp) -> _JoinTreeInfo: + """Walk the join tree to collect cardinality and join key information.""" + table_cardinalities: dict[str, str] = {} + table_ops: dict[str, SemanticTableOp] = {} - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values + def walk(node, is_right_of_many=False): + if isinstance(node, SemanticJoinOp): + this_is_many = node.cardinality in ("many", "cross") + walk(node.left, is_right_of_many=is_right_of_many) + walk(node.right, is_right_of_many=is_right_of_many or this_is_many) + elif isinstance(node, SemanticTableOp): + name = node.name + if name: + table_ops[name] = node + if is_right_of_many: + table_cardinalities[name] = "many" + elif name not in table_cardinalities: + table_cardinalities[name] = "one" - @property - def schema(self) -> Schema: - return self.source.schema + walk(join_op) - def to_untagged(self): - from .convert import _Resolver + # The leftmost leaf of the root is the "root" table + def find_leftmost(node): + if isinstance(node, SemanticJoinOp): + return find_leftmost(node.left) + return getattr(node, "name", None) - all_roots = _find_all_root_models(self.source) - base_tbl = _to_untagged(self.source) - dim_map = ( - {} - if isinstance(self.source, SemanticAggregateOp) - else _get_merged_fields(all_roots, "dimensions") - ) + root_name = find_leftmost(join_op) + if root_name: + table_cardinalities[root_name] = "root" - # Enrich table with derived dimensions so multi-level deps - # (e.g. d_two -> d_one -> distance) resolve correctly in filters. - # Best-effort: skip dimensions whose columns aren't available yet - # (e.g. join-based dims); those resolve through the Resolver fallback. - enriched = base_tbl - for dim_name in dim_map: - try: - enriched = _mutate_dimensions_with_dependencies( - enriched, [dim_name], dim_map - ) - except (TypeError, KeyError, AttributeError): - pass + has_join_many = any(c == "many" for c in table_cardinalities.values()) + # table_join_keys is used by the pre-aggregation path and by + # deferred join detection. Collect for all join trees. + table_join_keys = join_op._collect_join_keys_for_leaves() - pred_fn = _unwrap(self.predicate) - resolver = _Resolver(enriched, dim_map) - pred = _resolve_expr(pred_fn, resolver) - return enriched.filter(pred) + return _JoinTreeInfo( + has_join_many=has_join_many, + table_cardinalities=table_cardinalities, + table_join_keys=table_join_keys, + table_ops=table_ops, + ) - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() +@frozen +class _DeferrableJoin: + """A join_one that can be deferred until after aggregation.""" - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() + table_name: str # Name of the right (dimension) table + table_op: Any # SemanticTableOp of the right table + join_keys_left: tuple # Column names on the left side of the join + on_predicate: Any # Original join predicate (preserves key pairing) + deferred_dims: tuple # Prefixed dimension names to add post-agg -def _classify_fields( - fields: tuple[str, ...], - dimensions: dict, - measures: dict, -) -> tuple[list[str], list[str], list[str]]: - """Classify fields into dimensions, measures, and raw columns.""" - dims = [f for f in fields if f in dimensions] - meas = [f for f in fields if f in measures] - raw = [f for f in fields if f not in dimensions and f not in measures] - return dims, meas, raw +def _find_deferrable_joins( + join_op, + group_by_keys: tuple[str, ...], + agg_names: dict, + all_roots: list, + join_tree_info: "_JoinTreeInfo", + filters: list | None = None, +) -> list[_DeferrableJoin]: + """Identify join_one ops that can be deferred until after aggregation. + A join is deferrable when: + - cardinality == "one" + - Right table has is_entity dims matching the join key + - No measures from right table are used in the aggregation + - No filters reference the right table's columns + """ + filters = filters or [] + deferrable: list[_DeferrableJoin] = [] -def _process_nested_access_marker( - marker: NestedAccessMarker, - name: str, - tbl: ir.Table, -) -> tuple[ir.Table, ir.Value]: - """Process a NestedAccessMarker to unnest and build aggregation expression.""" - unnested = tbl - for array_col in marker.array_path: - if array_col in unnested.columns: - unnested = unnested.unnest(array_col) + def _filter_references_table(table_name, table_op): + """Check if any filter predicate references columns from this table.""" + if not filters: + return False + raw_tbl = _to_untagged(table_op) + for pred in filters: + pred_fn = _unwrap(pred) + try: + # If the predicate can be resolved against this table, + # it references its columns → can't defer + pred_fn(raw_tbl) + return True + except Exception: + pass + return False - if marker.operation == "count": - return unnested, unnested.count().name(name) + def walk(node): + if not isinstance(node, SemanticJoinOp): + return + # Recurse into left side (may have nested joins) + walk(node.left) - expr = getattr(unnested, marker.array_path[0]) - for field_name in marker.field_path: - expr = getattr(expr, field_name) + if node.cardinality != "one": + return - if marker.operation in ("sum", "mean", "min", "max", "nunique"): - agg_fn = getattr(expr, marker.operation) - return unnested, agg_fn().name(name) + right = node.right + if not isinstance(right, SemanticTableOp): + return - raise ValueError(f"Unknown operation: {marker.operation}") + right_name = right.name + if not right_name: + return + # Check: no filters reference the right table + if _filter_references_table(right_name, right): + return -def _evaluate_measures_with_unnesting( - measure_names: list[str], - measures: dict, - tbl: ir.Table, -) -> dict: - """Evaluate measures and apply automatic unnesting if needed. + # Check: right table has is_entity dims + right_dims = right.get_dimensions() + entity_dims = frozenset( + name for name, dim in right_dims.items() + if getattr(dim, "is_entity", False) + ) + if not entity_dims: + return - Returns dict with: - - table: potentially unnested table - - measure_exprs: list of evaluated measure expressions - - needs_unnesting: whether unnesting occurred - """ - meas_exprs = [] - current_tbl = tbl - needs_unnesting = False + # Check: no measures from right table are used in aggregation + for agg_name in agg_names: + if agg_name.startswith(f"{right_name}."): + return # Right table has measures in agg → can't defer - for name in measure_names: - result = measures[name](tbl) + # Get join keys for the right table + right_join_keys = join_tree_info.table_join_keys.get(right_name, set()) + if not right_join_keys: + return - if isinstance(result, NestedAccessMarker): - current_tbl, meas_expr = _process_nested_access_marker(result, name, current_tbl) - meas_exprs.append(meas_expr) - needs_unnesting = True - else: - meas_exprs.append(result.name(name)) + # Check: entity dims match join keys (PK = join key) + if entity_dims != frozenset(right_join_keys): + return - return { - "table": current_tbl, - "measure_exprs": meas_exprs, - "needs_unnesting": needs_unnesting, - } + # Get left-side join keys by extracting from the predicate + try: + left_tbl = ( + node.left.to_untagged(parent_requirements=None) + if isinstance(node.left, SemanticJoinOp) + else _to_untagged(node.left) + ) + right_tbl = _to_untagged(right) + join_keys_result = _extract_join_key_columns(node.on, left_tbl, right_tbl) + if not join_keys_result.is_success(): + return + left_cols = tuple(sorted(join_keys_result.left_columns)) + except Exception: + logger.debug("join-defer eligibility check failed", exc_info=True) + return + # Identify which group-by dimensions from the right table will be deferred + deferred_dims = tuple( + k for k in group_by_keys + if k.startswith(f"{right_name}.") + ) -def _build_select_or_aggregate( - tbl: ir.Table, - dim_exprs: list, - meas_exprs: list, - raw_exprs: list, -) -> ir.Table: - """Build appropriate select/aggregate based on what expressions exist.""" - if meas_exprs and dim_exprs: - return tbl.group_by(dim_exprs).aggregate(meas_exprs) - if meas_exprs: - return tbl.aggregate(meas_exprs) - if dim_exprs or raw_exprs: - return tbl.select(dim_exprs + raw_exprs) - return tbl + # Only deferrable if there are dims to add post-agg + if not deferred_dims: + return + deferrable.append(_DeferrableJoin( + table_name=right_name, + table_op=right, + join_keys_left=left_cols, + on_predicate=node.on, + deferred_dims=deferred_dims, + )) -class SemanticProjectOp(Relation): - source: Relation - fields: tuple[str, ...] + walk(join_op) + return deferrable - def __init__(self, source: Relation, fields: Iterable[str]) -> None: - super().__init__(source=Relation.__coerce__(source), fields=tuple(fields)) - def __repr__(self) -> str: - return _semantic_repr(self) +def _left_join_bridge(left, bridge, common_keys): + """Left-join *bridge* onto *left*, selecting only new columns from bridge.""" + preds = [left[c] == bridge[c] for c in common_keys] + bridge_only = tuple(c for c in bridge.columns if c not in frozenset(common_keys)) + return left.left_join(bridge, preds).select([left] + [bridge[c] for c in bridge_only]) - @property - def values(self) -> FrozenOrderedDict[str, Any]: - src_vals = self.source.values - return FrozenOrderedDict( - {k: v for k, v in src_vals.items() if k in self.fields}, - ) - @property - def schema(self) -> Schema: - return _SchemaClass(fields=_FrozenOrderedDict({k: v.dtype for k, v in self.values.items()})) +def _find_chain_bridge(pt, gb_col, prefix, raw, measure_names, join_tree_info): + """Find an intermediate table that chains *pt* grain to *raw* (dim table). - def to_untagged(self): - all_roots = _find_all_root_models(self.source) - tbl = _to_untagged(self.source) + Returns the bridged table, or *pt* unchanged if no chain is found. + """ + current_grain = frozenset(c for c in pt.columns if c not in measure_names) + dim_keys = join_tree_info.table_join_keys.get(prefix, frozenset()) + raw_columns = frozenset(raw.columns) - if not all_roots: - return tbl.select([getattr(tbl, f) for f in self.fields]) + for tname, tkeys in join_tree_info.table_join_keys.items(): + overlap_dim = tkeys & dim_keys & raw_columns + overlap_grain = tkeys & current_grain + if not (overlap_dim and overlap_grain): + continue - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - merged_measures = _get_merged_fields(all_roots, "measures") + inter_op = join_tree_info.table_ops.get(tname) + if inter_op is None: + continue - dims, meas, raw_fields = _classify_fields(self.fields, merged_dimensions, merged_measures) + inter_raw = _to_untagged(inter_op) + inter_cols = sorted(overlap_grain | overlap_dim) + inter_bridge = inter_raw.select( + [inter_raw[c] for c in inter_cols if c in inter_raw.columns] + ).distinct() - # Evaluate measures and handle automatic unnesting - meas_result = _evaluate_measures_with_unnesting(meas, merged_measures, tbl) + # Join dim table onto intermediate + dim_bridge_cols = sorted({gb_col} | overlap_dim) + dim_bridge = raw.select([raw[c] for c in dim_bridge_cols if c in raw.columns]).distinct() + chained = _left_join_bridge(inter_bridge, dim_bridge, sorted(overlap_dim)) - active_tbl = meas_result["table"] - meas_exprs = meas_result["measure_exprs"] - needs_unnesting = meas_result["needs_unnesting"] + # Join chained bridge onto pt — bridge on preserved (left) side + return _left_join_bridge(chained, pt, sorted(overlap_grain)) - # Re-evaluate dimensions on unnested table if needed - dim_exprs = ( - [merged_dimensions[name](active_tbl).name(name) for name in dims] - if needs_unnesting - else [merged_dimensions[name](tbl).name(name) for name in dims] - ) + return pt - # Get raw columns that still exist after unnesting - raw_exprs = [getattr(active_tbl, name) for name in raw_fields if name in active_tbl.columns] - return _build_select_or_aggregate(active_tbl, dim_exprs, meas_exprs, raw_exprs) +def _attach_dim_column(pt, gb_col, measure_names, join_tree_info, merged_dimensions): + """Attach a single group-by dimension column to a pre-agg result. + Looks up the raw table for the dimension's prefix, mutates the dim + column onto it, and bridges it to *pt* via shared join keys — either + directly or through an intermediate table. + """ + if "." not in gb_col: + return pt -class SemanticGroupByOp(Relation): - source: Relation - keys: tuple[str, ...] + prefix, _short = gb_col.split(".", 1) + dim_table_op = join_tree_info.table_ops.get(prefix) + dim_fn = merged_dimensions.get(gb_col) + if dim_table_op is None or dim_fn is None: + return pt - def __init__(self, source: Relation, keys: Iterable[str]) -> None: - super().__init__(source=Relation.__coerce__(source), keys=tuple(keys)) + raw = _mutate_dimensions_with_dependencies( + _to_untagged(dim_table_op), + [gb_col], + merged_dimensions, + ) + raw_columns = frozenset(raw.columns) + current_grain = tuple(c for c in pt.columns if c not in measure_names) + common_keys = tuple(c for c in current_grain if c in raw_columns) - def __repr__(self) -> str: - return _semantic_repr(self) + match common_keys: + case (): + # No direct overlap — chain through an intermediate table. + return _find_chain_bridge( + pt, + gb_col, + prefix, + raw, + measure_names, + join_tree_info, + ) + case _: + bridge = raw.select([raw[c] for c in (gb_col, *common_keys)]).distinct() + return _left_join_bridge(bridge, pt, common_keys) - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - @property - def schema(self) -> Schema: - return self.source.schema +def _is_mean_expr(expr): + """Check if an ibis expression is a Mean/Average reduction.""" + try: + return isinstance(expr.op(), _reductions_for_expr(expr).Mean) + except (AttributeError, TypeError): + return False - def to_untagged(self): - return _to_untagged(self.source) +def _is_count_distinct_expr(expr): + """Check if an ibis expression is a CountDistinct (nunique) reduction.""" + return safe( + lambda: isinstance(expr.op(), _reductions_for_expr(expr).CountDistinct) + )().value_or(False) -@frozen -class _MeasureSpec: - name: str - kind: str # 'agg' or 'calc' - value: Any +def _reagg_op_for_expr(expr): + """Return the correct re-aggregation operation name for an ibis expression. -@frozen -class _AggregationPlan: - agg_specs: FrozenDict[str, Callable] - calc_specs: FrozenDict[str, Any] - requested_measures: tuple[str, ...] - group_by_cols: tuple[str, ...] + Additive measures (SUM, COUNT) re-aggregate with ``sum``. + MIN and MAX re-aggregate with ``min`` and ``max`` respectively. + MEAN should never reach here — it is decomposed by ``_is_mean_expr``. + """ + op = expr.op() + reductions = _reductions_for_expr(expr) + if isinstance(op, reductions.Min): + return "min" + if isinstance(op, reductions.Max): + return "max" + if isinstance(op, reductions.Mean): + raise ValueError( + f"Mean expression {expr.get_name()!r} was not decomposed — " + "this is a bug in the pre-aggregation logic" + ) + if isinstance(op, reductions.CountDistinct): + raise ValueError( + f"CountDistinct expression {expr.get_name()!r} was not deferred — " + "this is a bug in the pre-aggregation logic" + ) + return "sum" -def _resolve_aggregation_exprs( - expr: Any, - merged_base_measures: dict, - merged_calc_measures: dict, - tbl: ir.Table, -) -> Any: - @curry - def find_in_calc_measures(expr, calc_measures): - for calc_name, calc_expr in calc_measures.items(): - if isinstance(calc_expr, AggregationExpr) and ( - calc_expr.column == expr.column and calc_expr.operation == expr.operation - ): - return Some(calc_name) - return Nothing +def _build_reagg(col_ref, op_name): + """Apply the correct re-aggregation to a column reference.""" + return getattr(col_ref, op_name)() - def resolve_aggregation(agg_expr): - matched = _find_matching_measure(agg_expr, merged_base_measures, tbl) - return matched.map(MeasureRef).value_or( - find_in_calc_measures(agg_expr, merged_calc_measures).map(MeasureRef).value_or(agg_expr) - ) - if isinstance(expr, AggregationExpr): - return resolve_aggregation(expr) - elif isinstance(expr, MethodCall): - return MethodCall( - receiver=_resolve_aggregation_exprs( - expr.receiver, merged_base_measures, merged_calc_measures, tbl - ), - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, - ) - elif isinstance(expr, BinOp): - return BinOp( - op=expr.op, - left=_resolve_aggregation_exprs( - expr.left, merged_base_measures, merged_calc_measures, tbl - ), - right=_resolve_aggregation_exprs( - expr.right, merged_base_measures, merged_calc_measures, tbl - ), - ) - elif isinstance(expr, AllOf) and isinstance(expr.ref, AggregationExpr): - return AllOf(resolve_aggregation(expr.ref)) - else: - return expr +def _partition_agg_specs_by_source( + agg_specs: dict[str, Callable], + all_roots: list[SemanticTableOp], +) -> dict[str | None, dict[str, Callable]]: + """Partition aggregation specs by their source table. + Prefixed measure names like ``"orders.total_amount"`` are mapped to + ``table="orders"``. Measures without a prefix go to ``None``. + """ + root_names = {r.name for r in all_roots if r.name} + partitioned: dict[str | None, dict[str, Callable]] = {} -def _create_measure_spec( - name: str, - fn_wrapped: Any, - scope: Any, - is_post_agg: bool, - merged_base_measures: dict, - merged_calc_measures: dict, - tbl: ir.Table, -) -> _MeasureSpec: - fn = _unwrap(fn_wrapped) - val = _resolve_expr(fn, scope) - val = _resolve_aggregation_exprs(val, merged_base_measures, merged_calc_measures, tbl) + for measure_name, fn in agg_specs.items(): + table_name = None + if "." in measure_name: + prefix = measure_name.split(".", 1)[0] + if prefix in root_names: + table_name = prefix + if table_name not in partitioned: + partitioned[table_name] = {} + partitioned[table_name][measure_name] = fn - if is_post_agg: - return _MeasureSpec(name=name, kind="agg", value=fn) + return partitioned - if isinstance(val, MeasureRef): - ref_name = val.name - if ref_name in merged_calc_measures: - calc_expr = merged_calc_measures[ref_name] - resolved = _resolve_aggregation_exprs( - calc_expr, merged_base_measures, merged_calc_measures, tbl - ) - return _MeasureSpec(name=name, kind="calc", value=resolved) - elif ref_name in merged_base_measures: - return _MeasureSpec(name=name, kind="agg", value=merged_base_measures[ref_name]) - else: - return _MeasureSpec(name=name, kind="calc", value=val) - if isinstance(val, AllOf | BinOp | MethodCall | int | float): - return _MeasureSpec(name=name, kind="calc", value=val) +class SemanticAggregateOp(Relation): + source: Relation + keys: tuple[str, ...] + aggs: dict[ + str, + Callable, + ] # Transformed to FrozenDict[str, _CallableWrapper] in __init__ + nested_columns: tuple[str, ...] = () # Track which columns are nested arrays - return _MeasureSpec(name=name, kind="agg", value=fn) + def __init__( + self, + source: Relation, + keys: Iterable[str], + aggs: dict[str, Callable] | None, + nested_columns: Iterable[str] | None = None, + ) -> None: + frozen_aggs = FrozenDict( + {name: _ensure_wrapped(fn) for name, fn in (aggs or {}).items()}, + ) + super().__init__( + source=Relation.__coerce__(source), + keys=tuple(keys), + aggs=frozen_aggs, + nested_columns=tuple(nested_columns or []), + ) + def __repr__(self) -> str: + return _semantic_repr(self) -def _make_agg_callable(measure: Any) -> Callable: - if _is_deferred(measure): - return lambda t: measure.resolve(ColumnScope(_tbl=t)) - elif callable(measure): - return lambda t: measure(ColumnScope(_tbl=t)) - else: - return lambda t: measure(t) + @property + def values(self) -> FrozenOrderedDict[str, Any]: + tbl = self.to_untagged() + return FrozenOrderedDict({col: tbl[col].op() for col in tbl.columns}) + @property + def schema(self) -> Schema: + return _SchemaClass(fields=_FrozenOrderedDict({n: v.dtype for n, v in self.values.items()})) -def _collect_all_measure_refs(calc_exprs) -> frozenset[str]: - all_refs = set() - for expr in calc_exprs: - _collect_measure_refs(expr, all_refs) - return frozenset(all_refs) + @property + def measures(self) -> tuple[str, ...]: + return () + def get_dimensions(self) -> Mapping[str, Dimension]: + """After aggregation, dimensions are materialized - return empty.""" + return {} -def _expand_calc_measure_refs( - expr: Any, - merged_base_measures: dict, - merged_calc_measures: dict, - tbl: ir.Table, - cache: dict[str, Any] | None = None, - path: tuple[str, ...] = (), -) -> Any: - """Inline calc-measure references transitively for multi-layer formulas.""" - cache = {} if cache is None else cache + def get_measures(self) -> Mapping[str, Measure]: + """After aggregation, measures are materialized - return empty.""" + return {} - def _lift_to_allof(value: Any) -> Any: - """Lift an expanded expression into totals-space via AllOf on refs.""" - if isinstance(value, MeasureRef): - return AllOf(value) - if isinstance(value, BinOp): - return BinOp( - op=value.op, - left=_lift_to_allof(value.left), - right=_lift_to_allof(value.right), - ) - if isinstance(value, MethodCall): - return MethodCall( - receiver=_lift_to_allof(value.receiver), - method=value.method, - args=value.args, - kwargs=value.kwargs, - ) - return value + def get_calculated_measures(self) -> Mapping[str, Any]: + """After aggregation, calculated measures are materialized - return empty.""" + return {} - if isinstance(expr, MeasureRef): - ref_name = expr.name - if ref_name not in merged_calc_measures: - return expr - if ref_name in cache: - return cache[ref_name] - if ref_name in path: - cycle = " -> ".join((*path, ref_name)) - raise ValueError(f"Circular calculated measure dependency detected: {cycle}") + @property + def required_columns(self) -> dict[str, set[str]]: + """ + Column requirements for this aggregation operation. - resolved = _resolve_aggregation_exprs( - merged_calc_measures[ref_name], merged_base_measures, merged_calc_measures, tbl - ) - expanded = _expand_calc_measure_refs( - resolved, - merged_base_measures, - merged_calc_measures, - tbl, - cache, - (*path, ref_name), + This property makes column requirements intrinsic to the aggregate operation, + similar to how `schema` is intrinsic to a relation. + + Returns: + Dict mapping table names to sets of required column names. + """ + all_roots = _find_all_root_models(self.source) + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + + base_tbl = ( + self.source.to_expr() if hasattr(self.source, "to_expr") else _to_untagged(self.source) ) - cache[ref_name] = expanded - return expanded - if isinstance(expr, MethodCall): - return MethodCall( - receiver=_expand_calc_measure_refs( - expr.receiver, merged_base_measures, merged_calc_measures, tbl, cache, path - ), - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, + table_names = [] + for root in all_roots: + if root.name: + table_names.append(root.name) + elif root._source_join is not None: + # All roots are SemanticTableOp with _source_join attribute + join_roots = _find_all_root_models(root._source_join) + table_names.extend([r.name for r in join_roots if r.name]) + + key_requirements = projection_utils.extract_requirements_from_keys( + keys=list(self.keys), + dimensions=merged_dimensions, + table=base_tbl, + table_names=table_names, ) - if isinstance(expr, BinOp): - return BinOp( - op=expr.op, - left=_expand_calc_measure_refs( - expr.left, merged_base_measures, merged_calc_measures, tbl, cache, path - ), - right=_expand_calc_measure_refs( - expr.right, merged_base_measures, merged_calc_measures, tbl, cache, path - ), + measure_requirements = projection_utils.extract_requirements_from_measures( + measures={name: _unwrap(fn) for name, fn in self.aggs.items()}, + table=base_tbl, + table_names=table_names, ) - if isinstance(expr, AllOf): - if isinstance(expr.ref, MeasureRef): - expanded_ref = _expand_calc_measure_refs( - expr.ref, merged_base_measures, merged_calc_measures, tbl, cache, path - ) - if isinstance(expanded_ref, MeasureRef): - return AllOf(expanded_ref) - return _lift_to_allof(expanded_ref) - return expr + combined = key_requirements.merge(measure_requirements) - return expr + if hasattr(self.source, "required_columns"): + source_reqs = projection_utils.TableRequirements.from_dict(self.source.required_columns) + combined = combined.merge(source_reqs) + return combined.to_dict() -def _build_aggregation_plan( - aggs: dict, - keys: tuple, - scope: Any, - is_post_agg: bool, - merged_base_measures: dict, - merged_calc_measures: dict, - tbl: ir.Table, -) -> _AggregationPlan: - specs = [ - _create_measure_spec( - name, fn, scope, is_post_agg, merged_base_measures, merged_calc_measures, tbl - ) - for name, fn in aggs.items() - ] - - agg_specs_list = [s for s in specs if s.kind == "agg"] - calc_specs_list = [s for s in specs if s.kind == "calc"] + def to_untagged(self): + all_roots = _find_all_root_models(self.source) - agg_specs = FrozenDict({s.name: _make_agg_callable(s.value) for s in agg_specs_list}) - calc_specs = FrozenDict({s.name: s.value for s in calc_specs_list}) + def find_join_in_tree(node): + """Find a SemanticJoinOp in the operation tree. - calc_cache: dict[str, Any] = {} - expanded_calc_specs = FrozenDict( - { - name: _expand_calc_measure_refs( - expr, - merged_base_measures, - merged_calc_measures, - tbl, - cache=calc_cache, - path=(name,), - ) - for name, expr in calc_specs.items() - } - ) + All Relation subclasses have source attribute except leaf operations. + """ + if isinstance(node, SemanticJoinOp): + return node + if isinstance(node, SemanticTableOp): + # Leaf operations - no source to traverse + return None + if node.source is not None: + return find_join_in_tree(node.source) + return None - referenced = _collect_all_measure_refs(expanded_calc_specs.values()) - additional_aggs = { - ref: _make_agg_callable(merged_base_measures[ref]) - for ref in referenced - if ref not in agg_specs and ref in merged_base_measures - } + def collect_filters_to_join(node): + """Collect filter predicates between this node and the join. - final_agg_specs = FrozenDict({**agg_specs, **additional_aggs}) + Returns tuple of filter predicate wrappers found between the + aggregate and the underlying join/table. + """ + filters = [] + current = node + while current is not None: + match current: + case SemanticFilterOp(): + filters.append(current.predicate) + current = current.source + case SemanticJoinOp() | SemanticTableOp(): + break + case SemanticGroupByOp(): + current = current.source + case _ if hasattr(current, "source"): + current = current.source + case _: + break + return tuple(filters) - return _AggregationPlan( - agg_specs=final_agg_specs, - calc_specs=expanded_calc_specs, - requested_measures=tuple(aggs.keys()), - group_by_cols=tuple(keys), - ) + join_op = find_join_in_tree(self.source) + if join_op is None and isinstance(self.source, SemanticGroupByOp): + grouped_source = self.source.source + if isinstance(grouped_source, SemanticTableOp): + # SemanticTableOp always has _source_join attribute + join_op = grouped_source._source_join -# --------------------------------------------------------------------------- -# Pre-aggregation helpers (fan-out / chasm trap prevention) -# --------------------------------------------------------------------------- + def has_prior_aggregate(node): + """Recursively check if there's a SemanticAggregateOp before any mutate.""" + if isinstance(node, SemanticAggregateOp): + return True + if isinstance(node, SemanticMutateOp): + return has_prior_aggregate(node.source) + if isinstance(node, SemanticGroupByOp): + return has_prior_aggregate(node.source) + if isinstance(node, SemanticTableOp | SemanticJoinOp): + return False + if hasattr(node, "source"): + return has_prior_aggregate(node.source) + return False + is_post_agg = has_prior_aggregate(self.source) + collected_filters = collect_filters_to_join(self.source) -@frozen -class _JoinTreeInfo: - """Information collected from the join tree for pre-aggregation decisions.""" + def collect_mutates_to_join(node): + """Collect SemanticMutateOp.post dicts between here and the join.""" + mutates = [] + current = node + while current is not None: + match current: + case SemanticMutateOp(): + mutates.append(current.post) + current = current.source + case SemanticJoinOp() | SemanticTableOp(): + break + case _ if hasattr(current, "source"): + current = current.source + case _: + break + return tuple(mutates) - has_join_many: bool - table_cardinalities: dict # table_name → "one"|"many"|"root" - table_join_keys: dict # table_name → {raw_col_names} - table_ops: dict # table_name → SemanticTableOp + # Dimension-only shortcut: when no measures are requested and all + # dimensions originate from a single joined table, query that table + # directly so that dimension members with no matching fact rows are + # still returned. (Fixes #224.) + if join_op is not None and not is_post_agg and not self.aggs: + shortcut = _dimension_only_source_table( + self.keys, all_roots, collected_filters, + ) + if shortcut is not None: + root_op, unprefixed_keys, dim_filters = shortcut + try: + tbl = _to_untagged(root_op) + root_dims = root_op.get_dimensions() + tbl = _mutate_dimensions_with_dependencies( + tbl, unprefixed_keys, root_dims, + ) + # Apply pre-aggregation filters on the dimension table. + for flt in dim_filters: + fn = _unwrap(flt) if hasattr(flt, "unwrap") else flt + tbl = tbl.filter(fn(tbl)) + result = tbl.select(unprefixed_keys).distinct() + # Rename columns to their prefixed (dotted) names so that + # downstream consumers see the expected column names. + prefix = root_op.name + rename_map = {f"{prefix}.{uk}": uk for uk in unprefixed_keys} + return result.rename(rename_map) + except Exception: + logger.debug( + "dimension-only shortcut failed for keys=%s; " + "falling back to standard path", + self.keys, + exc_info=True, + ) + # Pre-aggregation path: when join_many is present, aggregate each + # source table's measures at its own grain before joining. + if join_op is not None and not is_post_agg: + join_tree_info = _collect_join_tree_info(join_op) + if join_tree_info.has_join_many: + collected_mutates = collect_mutates_to_join(self.source) + return self._to_untagged_with_preagg( + all_roots, + join_op, + join_tree_info, + filters=collected_filters, + mutates=collected_mutates, + ) -def _collect_join_tree_info(join_op: SemanticJoinOp) -> _JoinTreeInfo: - """Walk the join tree to collect cardinality and join key information.""" - table_cardinalities: dict[str, str] = {} - table_ops: dict[str, SemanticTableOp] = {} + # Deferred join path: when join_one dimension tables can be + # joined AFTER aggregation for better performance. + deferrable = _find_deferrable_joins( + join_op, self.keys, self.aggs, all_roots, join_tree_info, + filters=collected_filters, + ) + if deferrable: + return self._to_untagged_with_deferred_joins( + all_roots, + join_op, + join_tree_info, + deferrable, + filters=collected_filters, + ) - def walk(node, is_right_of_many=False): - if isinstance(node, SemanticJoinOp): - this_is_many = node.cardinality in ("many", "cross") - walk(node.left, is_right_of_many=is_right_of_many) - walk(node.right, is_right_of_many=is_right_of_many or this_is_many) - elif isinstance(node, SemanticTableOp): - name = node.name - if name: - table_ops[name] = node - if is_right_of_many: - table_cardinalities[name] = "many" - elif name not in table_cardinalities: - table_cardinalities[name] = "one" + # Only use the join optimization if there are no filters after the join + # Otherwise we'd skip the filter operations + needed_tables: dict[str, set[str]] = {} + if join_op is not None and not collected_filters: + # Build table-level requirements from prefixed keys and measure + # names so the join can prune tables that the query never touches. + # When all names are unprefixed (single-table or post-agg), the + # dict is empty and `or None` disables pruning — correct since + # there's nothing to prune in that case. + for name in (*self.keys, *self.aggs.keys()): + if "." in name: + tbl_prefix = name.split(".", 1)[0] + needed_tables.setdefault(tbl_prefix, set()).add(name) + tbl = join_op.to_untagged(parent_requirements=needed_tables or None) + else: + tbl = _to_untagged(self.source) - walk(join_op) + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + merged_base_measures = _get_merged_fields(all_roots, "measures") + merged_calc_measures = _get_merged_fields(all_roots, "calc_measures") - # The leftmost leaf of the root is the "root" table - def find_leftmost(node): - if isinstance(node, SemanticJoinOp): - return find_leftmost(node.left) - return getattr(node, "name", None) + if needed_tables: + # Keep validation behavior for queried tables even when join pruning + # removes unrelated siblings. Validate each referenced root table + # against its own leaf schema so wrapped join aliases do not affect + # the error surface for local dimensions. + for root in all_roots: + if root.name not in needed_tables: + continue + root_dimensions = _get_field_dict(root, "dimensions") + if root_dimensions: + _mutate_dimensions_with_dependencies( + _to_untagged(root), + root_dimensions.keys(), + root_dimensions, + ) - root_name = find_leftmost(join_op) - if root_name: - table_cardinalities[root_name] = "root" + tbl = _mutate_dimensions_with_dependencies( + tbl, + [k for k in self.keys if k in merged_dimensions], + merged_dimensions, + ) - has_join_many = any(c == "many" for c in table_cardinalities.values()) - # table_join_keys is used by the pre-aggregation path and by - # deferred join detection. Collect for all join trees. - table_join_keys = join_op._collect_join_keys_for_leaves() + scope = ( + ColumnScope(_tbl=tbl) + if is_post_agg + else MeasureScope( + _tbl=tbl, + _known=list(merged_base_measures.keys()) + list(merged_calc_measures.keys()), + ) + ) - return _JoinTreeInfo( - has_join_many=has_join_many, - table_cardinalities=table_cardinalities, - table_join_keys=table_join_keys, - table_ops=table_ops, - ) - - -@frozen -class _DeferrableJoin: - """A join_one that can be deferred until after aggregation.""" - - table_name: str # Name of the right (dimension) table - table_op: Any # SemanticTableOp of the right table - join_keys_left: tuple # Column names on the left side of the join - on_predicate: Any # Original join predicate (preserves key pairing) - deferred_dims: tuple # Prefixed dimension names to add post-agg - - -def _find_deferrable_joins( - join_op, - group_by_keys: tuple[str, ...], - agg_names: dict, - all_roots: list, - join_tree_info: "_JoinTreeInfo", - filters: list | None = None, -) -> list[_DeferrableJoin]: - """Identify join_one ops that can be deferred until after aggregation. - - A join is deferrable when: - - cardinality == "one" - - Right table has is_entity dims matching the join key - - No measures from right table are used in the aggregation - - No filters reference the right table's columns - """ - filters = filters or [] - deferrable: list[_DeferrableJoin] = [] - - def _filter_references_table(table_name, table_op): - """Check if any filter predicate references columns from this table.""" - if not filters: - return False - raw_tbl = _to_untagged(table_op) - for pred in filters: - pred_fn = _unwrap(pred) - try: - # If the predicate can be resolved against this table, - # it references its columns → can't defer - pred_fn(raw_tbl) - return True - except Exception: - pass - return False - - def walk(node): - if not isinstance(node, SemanticJoinOp): - return - # Recurse into left side (may have nested joins) - walk(node.left) + plan = _build_aggregation_plan( + aggs=self.aggs, + keys=self.keys, + scope=scope, + is_post_agg=is_post_agg, + merged_base_measures=merged_base_measures, + merged_calc_measures=merged_calc_measures, + tbl=tbl, + ) - if node.cardinality != "one": - return + if plan.calc_specs or plan.group_by_cols: + return compile_grouped_with_all( + tbl, + list(plan.group_by_cols), + dict(plan.agg_specs), + dict(plan.calc_specs), + requested_measures=list(plan.requested_measures), + ) + else: + return tbl.aggregate({name: fn(tbl) for name, fn in plan.agg_specs.items()}) - right = node.right - if not isinstance(right, SemanticTableOp): - return + def _to_untagged_with_preagg( + self, + all_roots: list, + join_op: SemanticJoinOp, + join_tree_info: _JoinTreeInfo, + filters: list | None = None, + mutates: tuple | None = None, + ): + """Pre-aggregate each source table's measures at its own grain, then join. - right_name = right.name - if not right_name: - return + This prevents fan-out inflation when ``join_many`` is used. + """ + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + merged_base_measures = _get_merged_fields(all_roots, "measures") + merged_calc_measures = _get_merged_fields(all_roots, "calc_measures") + group_by_cols = list(self.keys) - # Check: no filters reference the right table - if _filter_references_table(right_name, right): - return + filters = filters or [] + mutates = mutates or () - # Check: right table has is_entity dims - right_dims = right.get_dimensions() - entity_dims = frozenset( - name for name, dim in right_dims.items() - if getattr(dim, "is_entity", False) + # Identify group-by keys that come from .mutate() rather than + # semantic dimensions — these need special handling. This is a + # process-of-elimination heuristic: any key that is not a known + # dimension, measure, or calc measure is assumed to originate + # from a .mutate() call in the operation chain. + mutated_gb_keys = frozenset( + k for k in group_by_cols + if k not in merged_dimensions + and k not in merged_base_measures + and k not in merged_calc_measures ) - if not entity_dims: - return - - # Check: no measures from right table are used in aggregation - for agg_name in agg_names: - if agg_name.startswith(f"{right_name}."): - return # Right table has measures in agg → can't defer - - # Get join keys for the right table - right_join_keys = join_tree_info.table_join_keys.get(right_name, set()) - if not right_join_keys: - return - - # Check: entity dims match join keys (PK = join key) - if entity_dims != frozenset(right_join_keys): - return - # Get left-side join keys by extracting from the predicate + # --- 1. Try to build the full joined table (for scope / dim bridge) --- + # Pre-agg needs all tables for dimension bridges — no pruning here. try: - left_tbl = ( - node.left.to_untagged(parent_requirements=None) - if isinstance(node.left, SemanticJoinOp) - else _to_untagged(node.left) + tbl = join_op.to_untagged(parent_requirements=None) + tbl = _mutate_dimensions_with_dependencies( + tbl, + [k for k in self.keys if k in merged_dimensions], + merged_dimensions, ) - right_tbl = _to_untagged(right) - join_keys_result = _extract_join_key_columns(node.on, left_tbl, right_tbl) - if not join_keys_result.is_success(): - return - left_cols = tuple(sorted(join_keys_result.left_columns)) - except Exception: - logger.debug("join-defer eligibility check failed", exc_info=True) - return + # Apply mutate operations so mutated group-by keys are available + # on the full joined table (needed for dimension bridges). + if mutated_gb_keys: + for mutate_post in mutates: + for name, fn_wrapped in mutate_post.items(): + if name in mutated_gb_keys: + try: + fn = _unwrap(fn_wrapped) + resolved = _resolve_expr(fn, tbl) + tbl = tbl.mutate(**{name: resolved}) + except (TypeError, KeyError, AttributeError, + ibis.common.exceptions.IbisTypeError): + pass + # Apply collected filters to the full joined table so that + # dimension bridges only include rows surviving the filter. + if filters: + from ..convert import _Resolver - # Identify which group-by dimensions from the right table will be deferred - deferred_dims = tuple( - k for k in group_by_keys - if k.startswith(f"{right_name}.") - ) + for pred in filters: + pred_fn = _unwrap(pred) + resolver = _Resolver(tbl, merged_dimensions) + pred_expr = _resolve_expr(pred_fn, resolver) + tbl = tbl.filter(pred_expr) + except Exception: + # chasm / column collision – work without full join + logger.debug("full-join construction failed; using per-table path", exc_info=True) + tbl = None - # Only deferrable if there are dims to add post-agg - if not deferred_dims: - return + # --- 2. Build aggregation plan --- + if tbl is not None: + scope = MeasureScope( + _tbl=tbl, + _known=list(merged_base_measures.keys()) + list(merged_calc_measures.keys()), + ) + plan = _build_aggregation_plan( + aggs=self.aggs, + keys=self.keys, + scope=scope, + is_post_agg=False, + merged_base_measures=merged_base_measures, + merged_calc_measures=merged_calc_measures, + tbl=tbl, + ) + else: + # Derive plan directly from metadata (chasm fallback) + agg_specs = {} + for name, fn_wrapped in self.aggs.items(): + if name in merged_base_measures: + agg_specs[name] = _make_agg_callable(merged_base_measures[name]) + plan = _AggregationPlan( + agg_specs=FrozenDict(agg_specs), + calc_specs=FrozenDict({}), + requested_measures=tuple(self.aggs.keys()), + group_by_cols=tuple(self.keys), + ) - deferrable.append(_DeferrableJoin( - table_name=right_name, - table_op=right, - join_keys_left=left_cols, - on_predicate=node.on, - deferred_dims=deferred_dims, - )) + # --- 3. Partition agg_specs by source table --- + partitioned = _partition_agg_specs_by_source(dict(plan.agg_specs), all_roots) - walk(join_op) - return deferrable + # --- 4. Pre-aggregate each source table on its raw table --- + _preagg_results: list = [] + # Track MEAN measures decomposed into SUM + COUNT for correct re-agg + _decomposed_means: dict[str, tuple[str, str]] = {} + # Track correct re-aggregation op per measure (default "sum") + _reagg_ops: dict[str, str] = {} + # Track COUNT DISTINCT measures deferred past pre-aggregation. + # Value: (table_name, short_name, raw_tbl, measure_fn) + _deferred_count_distincts: dict[str, tuple] = {} + for table_name, measures in partitioned.items(): + if table_name is None: + # Unprefixed – aggregate on the full join if available + if tbl is not None: + agg_exprs = {n: f(tbl) for n, f in measures.items()} + if group_by_cols: + r = tbl.group_by([tbl[c] for c in group_by_cols]).aggregate(**agg_exprs) + else: + r = tbl.aggregate(**agg_exprs) + _preagg_results.append(r) + continue -def _left_join_bridge(left, bridge, common_keys): - """Left-join *bridge* onto *left*, selecting only new columns from bridge.""" - preds = [left[c] == bridge[c] for c in common_keys] - bridge_only = tuple(c for c in bridge.columns if c not in frozenset(common_keys)) - return left.left_join(bridge, preds).select([left] + [bridge[c] for c in bridge_only]) + table_op = join_tree_info.table_ops.get(table_name) + if table_op is None: + continue + raw_tbl = _to_untagged(table_op) -def _find_chain_bridge(pt, gb_col, prefix, raw, measure_names, join_tree_info): - """Find an intermediate table that chains *pt* grain to *raw* (dim table). - - Returns the bridged table, or *pt* unchanged if no chain is found. - """ - current_grain = frozenset(c for c in pt.columns if c not in measure_names) - dim_keys = join_tree_info.table_join_keys.get(prefix, frozenset()) - raw_columns = frozenset(raw.columns) - - for tname, tkeys in join_tree_info.table_join_keys.items(): - overlap_dim = tkeys & dim_keys & raw_columns - overlap_grain = tkeys & current_grain - if not (overlap_dim and overlap_grain): - continue - - inter_op = join_tree_info.table_ops.get(tname) - if inter_op is None: - continue - - inter_raw = _to_untagged(inter_op) - inter_cols = sorted(overlap_grain | overlap_dim) - inter_bridge = inter_raw.select( - [inter_raw[c] for c in inter_cols if c in inter_raw.columns] - ).distinct() - - # Join dim table onto intermediate - dim_bridge_cols = sorted({gb_col} | overlap_dim) - dim_bridge = raw.select([raw[c] for c in dim_bridge_cols if c in raw.columns]).distinct() - chained = _left_join_bridge(inter_bridge, dim_bridge, sorted(overlap_dim)) - - # Join chained bridge onto pt — bridge on preserved (left) side - return _left_join_bridge(chained, pt, sorted(overlap_grain)) - - return pt - - -def _attach_dim_column(pt, gb_col, measure_names, join_tree_info, merged_dimensions): - """Attach a single group-by dimension column to a pre-agg result. - - Looks up the raw table for the dimension's prefix, mutates the dim - column onto it, and bridges it to *pt* via shared join keys — either - directly or through an intermediate table. - """ - if "." not in gb_col: - return pt - - prefix, _short = gb_col.split(".", 1) - dim_table_op = join_tree_info.table_ops.get(prefix) - dim_fn = merged_dimensions.get(gb_col) - if dim_table_op is None or dim_fn is None: - return pt - - raw = _mutate_dimensions_with_dependencies( - _to_untagged(dim_table_op), - [gb_col], - merged_dimensions, - ) - raw_columns = frozenset(raw.columns) - current_grain = tuple(c for c in pt.columns if c not in measure_names) - common_keys = tuple(c for c in current_grain if c in raw_columns) - - match common_keys: - case (): - # No direct overlap — chain through an intermediate table. - return _find_chain_bridge( - pt, - gb_col, - prefix, - raw, - measure_names, - join_tree_info, - ) - case _: - bridge = raw.select([raw[c] for c in (gb_col, *common_keys)]).distinct() - return _left_join_bridge(bridge, pt, common_keys) - - -def _is_mean_expr(expr): - """Check if an ibis expression is a Mean/Average reduction.""" - try: - return isinstance(expr.op(), _reductions_for_expr(expr).Mean) - except (AttributeError, TypeError): - return False - - -def _is_count_distinct_expr(expr): - """Check if an ibis expression is a CountDistinct (nunique) reduction.""" - return safe( - lambda: isinstance(expr.op(), _reductions_for_expr(expr).CountDistinct) - )().value_or(False) - - -def _reagg_op_for_expr(expr): - """Return the correct re-aggregation operation name for an ibis expression. - - Additive measures (SUM, COUNT) re-aggregate with ``sum``. - MIN and MAX re-aggregate with ``min`` and ``max`` respectively. - MEAN should never reach here — it is decomposed by ``_is_mean_expr``. - """ - op = expr.op() - reductions = _reductions_for_expr(expr) - if isinstance(op, reductions.Min): - return "min" - if isinstance(op, reductions.Max): - return "max" - if isinstance(op, reductions.Mean): - raise ValueError( - f"Mean expression {expr.get_name()!r} was not decomposed — " - "this is a bug in the pre-aggregation logic" - ) - if isinstance(op, reductions.CountDistinct): - raise ValueError( - f"CountDistinct expression {expr.get_name()!r} was not deferred — " - "this is a bug in the pre-aggregation logic" - ) - return "sum" - - -def _build_reagg(col_ref, op_name): - """Apply the correct re-aggregation to a column reference.""" - return getattr(col_ref, op_name)() - - -def _partition_agg_specs_by_source( - agg_specs: dict[str, Callable], - all_roots: list[SemanticTableOp], -) -> dict[str | None, dict[str, Callable]]: - """Partition aggregation specs by their source table. - - Prefixed measure names like ``"orders.total_amount"`` are mapped to - ``table="orders"``. Measures without a prefix go to ``None``. - """ - root_names = {r.name for r in all_roots if r.name} - partitioned: dict[str | None, dict[str, Callable]] = {} - - for measure_name, fn in agg_specs.items(): - table_name = None - if "." in measure_name: - prefix = measure_name.split(".", 1)[0] - if prefix in root_names: - table_name = prefix - if table_name not in partitioned: - partitioned[table_name] = {} - partitioned[table_name][measure_name] = fn - - return partitioned - - -class SemanticAggregateOp(Relation): - source: Relation - keys: tuple[str, ...] - aggs: dict[ - str, - Callable, - ] # Transformed to FrozenDict[str, _CallableWrapper] in __init__ - nested_columns: tuple[str, ...] = () # Track which columns are nested arrays - - def __init__( - self, - source: Relation, - keys: Iterable[str], - aggs: dict[str, Callable] | None, - nested_columns: Iterable[str] | None = None, - ) -> None: - frozen_aggs = FrozenDict( - {name: _ensure_wrapped(fn) for name, fn in (aggs or {}).items()}, - ) - super().__init__( - source=Relation.__coerce__(source), - keys=tuple(keys), - aggs=frozen_aggs, - nested_columns=tuple(nested_columns or []), - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - tbl = self.to_untagged() - return FrozenOrderedDict({col: tbl[col].op() for col in tbl.columns}) - - @property - def schema(self) -> Schema: - return _SchemaClass(fields=_FrozenOrderedDict({n: v.dtype for n, v in self.values.items()})) - - @property - def measures(self) -> tuple[str, ...]: - return () - - def get_dimensions(self) -> Mapping[str, Dimension]: - """After aggregation, dimensions are materialized - return empty.""" - return {} - - def get_measures(self) -> Mapping[str, Measure]: - """After aggregation, measures are materialized - return empty.""" - return {} - - def get_calculated_measures(self) -> Mapping[str, Any]: - """After aggregation, calculated measures are materialized - return empty.""" - return {} - - @property - def required_columns(self) -> dict[str, set[str]]: - """ - Column requirements for this aggregation operation. - - This property makes column requirements intrinsic to the aggregate operation, - similar to how `schema` is intrinsic to a relation. - - Returns: - Dict mapping table names to sets of required column names. - """ - all_roots = _find_all_root_models(self.source) - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - - base_tbl = ( - self.source.to_expr() if hasattr(self.source, "to_expr") else _to_untagged(self.source) - ) - - table_names = [] - for root in all_roots: - if root.name: - table_names.append(root.name) - elif root._source_join is not None: - # All roots are SemanticTableOp with _source_join attribute - join_roots = _find_all_root_models(root._source_join) - table_names.extend([r.name for r in join_roots if r.name]) - - key_requirements = projection_utils.extract_requirements_from_keys( - keys=list(self.keys), - dimensions=merged_dimensions, - table=base_tbl, - table_names=table_names, - ) - - measure_requirements = projection_utils.extract_requirements_from_measures( - measures={name: _unwrap(fn) for name, fn in self.aggs.items()}, - table=base_tbl, - table_names=table_names, - ) - - combined = key_requirements.merge(measure_requirements) - - if hasattr(self.source, "required_columns"): - source_reqs = projection_utils.TableRequirements.from_dict(self.source.required_columns) - combined = combined.merge(source_reqs) - - return combined.to_dict() - - def to_untagged(self): - all_roots = _find_all_root_models(self.source) - - def find_join_in_tree(node): - """Find a SemanticJoinOp in the operation tree. - - All Relation subclasses have source attribute except leaf operations. - """ - if isinstance(node, SemanticJoinOp): - return node - if isinstance(node, SemanticTableOp): - # Leaf operations - no source to traverse - return None - if node.source is not None: - return find_join_in_tree(node.source) - return None - - def collect_filters_to_join(node): - """Collect filter predicates between this node and the join. - - Returns tuple of filter predicate wrappers found between the - aggregate and the underlying join/table. - """ - filters = [] - current = node - while current is not None: - match current: - case SemanticFilterOp(): - filters.append(current.predicate) - current = current.source - case SemanticJoinOp() | SemanticTableOp(): - break - case SemanticGroupByOp(): - current = current.source - case _ if hasattr(current, "source"): - current = current.source - case _: - break - return tuple(filters) - - join_op = find_join_in_tree(self.source) - - if join_op is None and isinstance(self.source, SemanticGroupByOp): - grouped_source = self.source.source - if isinstance(grouped_source, SemanticTableOp): - # SemanticTableOp always has _source_join attribute - join_op = grouped_source._source_join - - def has_prior_aggregate(node): - """Recursively check if there's a SemanticAggregateOp before any mutate.""" - if isinstance(node, SemanticAggregateOp): - return True - if isinstance(node, SemanticMutateOp): - return has_prior_aggregate(node.source) - if isinstance(node, SemanticGroupByOp): - return has_prior_aggregate(node.source) - if isinstance(node, SemanticTableOp | SemanticJoinOp): - return False - if hasattr(node, "source"): - return has_prior_aggregate(node.source) - return False - - is_post_agg = has_prior_aggregate(self.source) - collected_filters = collect_filters_to_join(self.source) - - def collect_mutates_to_join(node): - """Collect SemanticMutateOp.post dicts between here and the join.""" - mutates = [] - current = node - while current is not None: - match current: - case SemanticMutateOp(): - mutates.append(current.post) - current = current.source - case SemanticJoinOp() | SemanticTableOp(): - break - case _ if hasattr(current, "source"): - current = current.source - case _: - break - return tuple(mutates) - - # Dimension-only shortcut: when no measures are requested and all - # dimensions originate from a single joined table, query that table - # directly so that dimension members with no matching fact rows are - # still returned. (Fixes #224.) - if join_op is not None and not is_post_agg and not self.aggs: - shortcut = _dimension_only_source_table( - self.keys, all_roots, collected_filters, - ) - if shortcut is not None: - root_op, unprefixed_keys, dim_filters = shortcut - try: - tbl = _to_untagged(root_op) - root_dims = root_op.get_dimensions() - tbl = _mutate_dimensions_with_dependencies( - tbl, unprefixed_keys, root_dims, - ) - # Apply pre-aggregation filters on the dimension table. - for flt in dim_filters: - fn = _unwrap(flt) if hasattr(flt, "unwrap") else flt - tbl = tbl.filter(fn(tbl)) - result = tbl.select(unprefixed_keys).distinct() - # Rename columns to their prefixed (dotted) names so that - # downstream consumers see the expected column names. - prefix = root_op.name - rename_map = {f"{prefix}.{uk}": uk for uk in unprefixed_keys} - return result.rename(rename_map) - except Exception: - logger.debug( - "dimension-only shortcut failed for keys=%s; " - "falling back to standard path", - self.keys, - exc_info=True, - ) - - # Pre-aggregation path: when join_many is present, aggregate each - # source table's measures at its own grain before joining. - if join_op is not None and not is_post_agg: - join_tree_info = _collect_join_tree_info(join_op) - if join_tree_info.has_join_many: - collected_mutates = collect_mutates_to_join(self.source) - return self._to_untagged_with_preagg( - all_roots, - join_op, - join_tree_info, - filters=collected_filters, - mutates=collected_mutates, - ) - - # Deferred join path: when join_one dimension tables can be - # joined AFTER aggregation for better performance. - deferrable = _find_deferrable_joins( - join_op, self.keys, self.aggs, all_roots, join_tree_info, - filters=collected_filters, - ) - if deferrable: - return self._to_untagged_with_deferred_joins( - all_roots, - join_op, - join_tree_info, - deferrable, - filters=collected_filters, - ) - - # Only use the join optimization if there are no filters after the join - # Otherwise we'd skip the filter operations - needed_tables: dict[str, set[str]] = {} - if join_op is not None and not collected_filters: - # Build table-level requirements from prefixed keys and measure - # names so the join can prune tables that the query never touches. - # When all names are unprefixed (single-table or post-agg), the - # dict is empty and `or None` disables pruning — correct since - # there's nothing to prune in that case. - for name in (*self.keys, *self.aggs.keys()): - if "." in name: - tbl_prefix = name.split(".", 1)[0] - needed_tables.setdefault(tbl_prefix, set()).add(name) - tbl = join_op.to_untagged(parent_requirements=needed_tables or None) - else: - tbl = _to_untagged(self.source) - - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - merged_base_measures = _get_merged_fields(all_roots, "measures") - merged_calc_measures = _get_merged_fields(all_roots, "calc_measures") - - if needed_tables: - # Keep validation behavior for queried tables even when join pruning - # removes unrelated siblings. Validate each referenced root table - # against its own leaf schema so wrapped join aliases do not affect - # the error surface for local dimensions. - for root in all_roots: - if root.name not in needed_tables: - continue - root_dimensions = _get_field_dict(root, "dimensions") - if root_dimensions: - _mutate_dimensions_with_dependencies( - _to_untagged(root), - root_dimensions.keys(), - root_dimensions, - ) - - tbl = _mutate_dimensions_with_dependencies( - tbl, - [k for k in self.keys if k in merged_dimensions], - merged_dimensions, - ) - - scope = ( - ColumnScope(_tbl=tbl) - if is_post_agg - else MeasureScope( - _tbl=tbl, - _known=list(merged_base_measures.keys()) + list(merged_calc_measures.keys()), - ) - ) - - plan = _build_aggregation_plan( - aggs=self.aggs, - keys=self.keys, - scope=scope, - is_post_agg=is_post_agg, - merged_base_measures=merged_base_measures, - merged_calc_measures=merged_calc_measures, - tbl=tbl, - ) - - if plan.calc_specs or plan.group_by_cols: - return compile_grouped_with_all( - tbl, - list(plan.group_by_cols), - dict(plan.agg_specs), - dict(plan.calc_specs), - requested_measures=list(plan.requested_measures), - ) - else: - return tbl.aggregate({name: fn(tbl) for name, fn in plan.agg_specs.items()}) - - def _to_untagged_with_preagg( - self, - all_roots: list, - join_op: SemanticJoinOp, - join_tree_info: _JoinTreeInfo, - filters: list | None = None, - mutates: tuple | None = None, - ): - """Pre-aggregate each source table's measures at its own grain, then join. - - This prevents fan-out inflation when ``join_many`` is used. - """ - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - merged_base_measures = _get_merged_fields(all_roots, "measures") - merged_calc_measures = _get_merged_fields(all_roots, "calc_measures") - group_by_cols = list(self.keys) - - filters = filters or [] - mutates = mutates or () - - # Identify group-by keys that come from .mutate() rather than - # semantic dimensions — these need special handling. This is a - # process-of-elimination heuristic: any key that is not a known - # dimension, measure, or calc measure is assumed to originate - # from a .mutate() call in the operation chain. - mutated_gb_keys = frozenset( - k for k in group_by_cols - if k not in merged_dimensions - and k not in merged_base_measures - and k not in merged_calc_measures - ) - - # --- 1. Try to build the full joined table (for scope / dim bridge) --- - # Pre-agg needs all tables for dimension bridges — no pruning here. - try: - tbl = join_op.to_untagged(parent_requirements=None) - tbl = _mutate_dimensions_with_dependencies( - tbl, - [k for k in self.keys if k in merged_dimensions], - merged_dimensions, - ) - # Apply mutate operations so mutated group-by keys are available - # on the full joined table (needed for dimension bridges). - if mutated_gb_keys: - for mutate_post in mutates: - for name, fn_wrapped in mutate_post.items(): - if name in mutated_gb_keys: - try: - fn = _unwrap(fn_wrapped) - resolved = _resolve_expr(fn, tbl) - tbl = tbl.mutate(**{name: resolved}) - except (TypeError, KeyError, AttributeError, - ibis.common.exceptions.IbisTypeError): - pass - # Apply collected filters to the full joined table so that - # dimension bridges only include rows surviving the filter. - if filters: - from .convert import _Resolver - - for pred in filters: - pred_fn = _unwrap(pred) - resolver = _Resolver(tbl, merged_dimensions) - pred_expr = _resolve_expr(pred_fn, resolver) - tbl = tbl.filter(pred_expr) - except Exception: - # chasm / column collision – work without full join - logger.debug("full-join construction failed; using per-table path", exc_info=True) - tbl = None - - # --- 2. Build aggregation plan --- - if tbl is not None: - scope = MeasureScope( - _tbl=tbl, - _known=list(merged_base_measures.keys()) + list(merged_calc_measures.keys()), - ) - plan = _build_aggregation_plan( - aggs=self.aggs, - keys=self.keys, - scope=scope, - is_post_agg=False, - merged_base_measures=merged_base_measures, - merged_calc_measures=merged_calc_measures, - tbl=tbl, - ) - else: - # Derive plan directly from metadata (chasm fallback) - agg_specs = {} - for name, fn_wrapped in self.aggs.items(): - if name in merged_base_measures: - agg_specs[name] = _make_agg_callable(merged_base_measures[name]) - plan = _AggregationPlan( - agg_specs=FrozenDict(agg_specs), - calc_specs=FrozenDict({}), - requested_measures=tuple(self.aggs.keys()), - group_by_cols=tuple(self.keys), - ) - - # --- 3. Partition agg_specs by source table --- - partitioned = _partition_agg_specs_by_source(dict(plan.agg_specs), all_roots) - - # --- 4. Pre-aggregate each source table on its raw table --- - _preagg_results: list = [] - # Track MEAN measures decomposed into SUM + COUNT for correct re-agg - _decomposed_means: dict[str, tuple[str, str]] = {} - # Track correct re-aggregation op per measure (default "sum") - _reagg_ops: dict[str, str] = {} - # Track COUNT DISTINCT measures deferred past pre-aggregation. - # Value: (table_name, short_name, raw_tbl, measure_fn) - _deferred_count_distincts: dict[str, tuple] = {} - - for table_name, measures in partitioned.items(): - if table_name is None: - # Unprefixed – aggregate on the full join if available - if tbl is not None: - agg_exprs = {n: f(tbl) for n, f in measures.items()} - if group_by_cols: - r = tbl.group_by([tbl[c] for c in group_by_cols]).aggregate(**agg_exprs) - else: - r = tbl.aggregate(**agg_exprs) - _preagg_results.append(r) - continue - - table_op = join_tree_info.table_ops.get(table_name) - if table_op is None: - continue - - raw_tbl = _to_untagged(table_op) - - # Push applicable filters to per-table raw tables - if filters: - all_pushed = True - for pred in filters: - pred_fn = _unwrap(pred) - try: - pred_expr = pred_fn(raw_tbl) - raw_tbl = raw_tbl.filter(pred_expr) - except Exception: - all_pushed = False # filter references columns not on this table - - # If some filters couldn't be pushed (cross-table filters), - # restrict via join keys from the filtered full joined table - # or from filtered raw tables that own the filter columns. - if not all_pushed: - jk = join_tree_info.table_join_keys.get(table_name, set()) - if tbl is not None: - shared = sorted(jk & set(raw_tbl.columns) & set(tbl.columns)) - if shared: - key_bridge = tbl.select([tbl[c] for c in shared]).distinct() - preds = [raw_tbl[c] == key_bridge[c] for c in shared] - raw_tbl = raw_tbl.inner_join(key_bridge, preds).select(raw_tbl) - else: - # Chasm fallback: build key bridge from other raw tables - for other_name, other_op in join_tree_info.table_ops.items(): - if other_name == table_name: - continue - other_raw = _to_untagged(other_op) - can_filter = False - for pred in filters: - pred_fn = _unwrap(pred) - try: - pred_expr = pred_fn(other_raw) - other_raw = other_raw.filter(pred_expr) - can_filter = True - except Exception: - pass - if can_filter: - other_jk = join_tree_info.table_join_keys.get(other_name, set()) - shared = sorted( - jk & other_jk & set(raw_tbl.columns) & set(other_raw.columns) - ) - if shared: - key_bridge = other_raw.select( - [other_raw[c] for c in shared] - ).distinct() - preds = [raw_tbl[c] == key_bridge[c] for c in shared] - raw_tbl = raw_tbl.inner_join(key_bridge, preds).select(raw_tbl) - break - - # Apply mutated group-by keys to this raw table so they can - # participate in the per-table grain computation. - if mutated_gb_keys: - for mutate_post in mutates: - for name, fn_wrapped in mutate_post.items(): - if name in mutated_gb_keys and name not in raw_tbl.columns: - try: - fn = _unwrap(fn_wrapped) - resolved = _resolve_expr(fn, raw_tbl) - raw_tbl = raw_tbl.mutate(**{name: resolved}) - except (TypeError, KeyError, AttributeError, - ibis.common.exceptions.IbisTypeError): - pass - - table_measures = _get_field_dict(table_op, "measures") - table_dims = _get_field_dict(table_op, "dimensions") - # Captured after mutated-key application above so the grain - # computation can see the newly-added columns. - raw_columns = set(raw_tbl.columns) - - # Build agg expressions on the raw table - agg_exprs: dict = {} - for mname, _mfn in measures.items(): - short = mname.split(".", 1)[1] if "." in mname else mname - if short in table_measures: - expr = table_measures[short](raw_tbl) - # Decompose MEAN into SUM + COUNT for correct re-aggregation - if _is_mean_expr(expr): - mean_op = expr.op() - base_col = mean_op.arg.to_expr() - sum_col = f"_sum__{mname}" - count_col = f"_count__{mname}" - agg_exprs[sum_col] = base_col.sum() - agg_exprs[count_col] = base_col.count() - _decomposed_means[mname] = (sum_col, count_col) - elif _is_count_distinct_expr(expr): - # COUNT DISTINCT is immune to fan-out — defer past pre-agg - _deferred_count_distincts[mname] = ( - table_name, short, raw_tbl, table_measures[short], - ) - else: - _reagg_ops[mname] = _reagg_op_for_expr(expr) - agg_exprs[mname] = expr - - if not agg_exprs: - continue - - # --- Compute grain --- - if not group_by_cols: - # No group-by → scalar aggregate - pt = raw_tbl.aggregate(**agg_exprs) - # Recompute MEAN from SUM/COUNT for scalar results - for mname, (sc, cc) in _decomposed_means.items(): - if sc in pt.columns and cc in pt.columns: - pt = pt.mutate(**{mname: pt[sc] / pt[cc]}) - pt = pt.drop(sc, cc) - _preagg_results.append(pt) - continue - - # a) group-by dims that live on this table - _local_dims = [] - has_cross_table_gb = False - for gb_key in group_by_cols: - # Mutated columns are already on raw_tbl (applied above) - if gb_key in mutated_gb_keys: - if gb_key in raw_columns and gb_key not in _local_dims: - _local_dims.append(gb_key) - elif "." in gb_key: - prefix, short = gb_key.split(".", 1) - if prefix == table_name and short in table_dims: - dim_fn = table_dims[short] - if callable(dim_fn): - dim_expr = dim_fn(raw_tbl) - col_name = dim_expr.get_name() - if col_name == short and col_name in raw_columns: - # Simple column reference — use directly - if col_name not in _local_dims: - _local_dims.append(col_name) - elif col_name in raw_columns or short not in raw_columns: - # Derived dimension — materialize on raw_tbl - raw_tbl = raw_tbl.mutate(**{short: dim_expr}) - raw_columns = set(raw_tbl.columns) - if short not in _local_dims: - _local_dims.append(short) - elif prefix != table_name: - has_cross_table_gb = True - - join_keys = join_tree_info.table_join_keys.get(table_name, set()) - available_jk = tuple(jk for jk in sorted(join_keys) if jk in raw_columns) - - # b) if none found, use join keys; if cross-table gb, augment with them - match (_local_dims, has_cross_table_gb): - case ([], _): - grain = available_jk - case (_, True): - grain = tuple( - dict.fromkeys( - _local_dims + [jk for jk in available_jk if jk not in _local_dims] - ) - ) - case _: - grain = tuple(_local_dims) - - if grain: - _preagg_results.append( - raw_tbl.group_by([raw_tbl[c] for c in grain]).aggregate(**agg_exprs) - ) - else: - _preagg_results.append(raw_tbl.aggregate(**agg_exprs)) - - # Freeze mutable accumulators - preagg_results = tuple(_preagg_results) - decomposed_means = tuple(_decomposed_means.items()) - reagg_ops = tuple(_reagg_ops.items()) - - if not preagg_results and not _deferred_count_distincts: - if tbl is not None: - return tbl.aggregate({n: f(tbl) for n, f in plan.agg_specs.items()}) - raise ValueError("No aggregation results and full join unavailable") - - # --- 5. Combine pre-agg results --- - result = None - if preagg_results: - if not group_by_cols: - # Cross-join all scalar results - result = preagg_results[0] - for pt in preagg_results[1:]: - result = result.cross_join(pt) - elif tbl is not None: - result = self._join_preagg_with_dim_bridge( - preagg_results, - plan, - tbl, - group_by_cols, - decomposed_means=decomposed_means, - reagg_ops=reagg_ops, - ) - else: - # Chasm fallback with group-by: build minimal dim bridge from raw tables - result = self._build_minimal_dim_bridge( - preagg_results, - plan, - group_by_cols, - join_tree_info, - merged_dimensions, - decomposed_means=decomposed_means, - reagg_ops=reagg_ops, - ) - - # --- 5b. Compute deferred COUNT DISTINCT measures --- - # Optimisation: compute on the raw source table when all group-by - # columns are local (avoids scanning the fanned-out joined table). - # Fall back to the full joined table only for cross-table group-bys. - if _deferred_count_distincts: - cd_parts: list = [] - for mname, (_src_tbl_name, _short, src_raw, src_fn) in ( - _deferred_count_distincts.items() - ): - src_cols = set(src_raw.columns) - # Determine which group-by columns live on the raw source table - local_gb = [c for c in group_by_cols if c in src_cols] - if not group_by_cols or frozenset(local_gb) == frozenset(group_by_cols): - # All group-by cols are local (or scalar) — compute on raw table - cd_expr = src_fn(src_raw) - if local_gb: - cd_pt = src_raw.group_by( - [src_raw[c] for c in local_gb] - ).aggregate(**{mname: cd_expr}) - else: - cd_pt = src_raw.aggregate(**{mname: cd_expr}) - cd_parts.append(cd_pt) - else: - # Cross-table group-by — need the full joined table - if tbl is None: - raise ValueError( - "COUNT DISTINCT measures require the full joined table " - "for cross-table group-by but it is unavailable (chasm " - "fallback). Use join_one for reference tables or remove " - "count-distinct measures from this query." - ) - cd_expr = plan.agg_specs[mname](tbl) - gb_available = [c for c in group_by_cols if c in tbl.columns] - if gb_available: - cd_pt = tbl.group_by( - [tbl[c] for c in gb_available] - ).aggregate(**{mname: cd_expr}) - else: - cd_pt = tbl.aggregate(**{mname: cd_expr}) - cd_parts.append(cd_pt) - - # Merge count-distinct parts into result - for cd_pt in cd_parts: - cd_meas = [c for c in cd_pt.columns if c in _deferred_count_distincts] - cd_grain = [c for c in cd_pt.columns if c not in _deferred_count_distincts] - if result is None: - result = cd_pt - elif cd_grain: - common = [c for c in cd_grain if c in result.columns] - if common: - preds = [result[c] == cd_pt[c] for c in common] - result = result.left_join(cd_pt, preds).select( - [result] + [cd_pt[m] for m in cd_meas] - ) - else: - result = result.cross_join(cd_pt) - else: - result = result.cross_join(cd_pt) - - # --- 6. Apply calc_specs --- - if plan.calc_specs: - result = self._apply_calc_specs(result, plan, tbl) - - # --- 7. Select requested columns --- - available = frozenset(result.columns) - select_cols = tuple( - dict.fromkeys( - c - for c in (*plan.group_by_cols, *plan.requested_measures, *plan.calc_specs.keys()) - if c in available - ) - ) - if select_cols: - result = result.select([result[c] for c in select_cols]) - - return result - - def _to_untagged_with_deferred_joins( - self, - all_roots: list, - join_op: "SemanticJoinOp", - join_tree_info: "_JoinTreeInfo", - deferrable: list["_DeferrableJoin"], - filters: list | None = None, - ): - """Aggregate first, then LEFT JOIN deferred dimension tables. - - For ``join_one`` dimension lookups where the right table's PK matches - the join key and no measures from the right table are used, we can: - 1. Strip deferred tables from the join tree - 2. Add join keys to GROUP BY so they survive aggregation - 3. Aggregate on the core (non-deferred) tables - 4. LEFT JOIN deferred dimension tables onto the aggregated result - """ - filters = filters or [] - deferred_names = {d.table_name for d in deferrable} - - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - merged_base_measures = _get_merged_fields(all_roots, "measures") - merged_calc_measures = _get_merged_fields(all_roots, "calc_measures") - - # --- 1. Rebuild join tree without deferred tables --- - def strip_deferred(node): - """Remove deferred right-side tables from the join tree.""" - if not isinstance(node, SemanticJoinOp): - return node - # Recurse left - new_left = strip_deferred(node.left) - # If right side is a deferred table, return just the left - if ( - isinstance(node.right, SemanticTableOp) - and node.right.name in deferred_names - ): - return new_left - return SemanticJoinOp( - left=new_left, - right=node.right, - how=node.how, - on=node.on, - cardinality=node.cardinality, - ) - - core_join = strip_deferred(join_op) - - # --- 2. Build core table and aggregate --- - if isinstance(core_join, SemanticJoinOp): - core_tbl = core_join.to_untagged(parent_requirements=None) - else: - core_tbl = _to_untagged(core_join) - - # Resolve group-by dimensions on the core table - core_group_keys = [] - for k in self.keys: - if k in deferred_names or any(k.startswith(f"{dn}.") for dn in deferred_names): - # Skip deferred table's own dimensions from group-by - # (they'll be added via post-agg join) - continue - core_group_keys.append(k) - - # Add left-side join keys to group-by so they survive aggregation - for d in deferrable: - for jk in d.join_keys_left: - if jk not in core_group_keys and jk in core_tbl.columns: - core_group_keys.append(jk) - - # Resolve dimensions for group-by on the core table - core_tbl = _mutate_dimensions_with_dependencies( - core_tbl, - [k for k in core_group_keys if k in merged_dimensions], - merged_dimensions, - ) - - # Apply filters - if filters: - from .convert import _Resolver - - for pred in filters: - pred_fn = _unwrap(pred) - try: - resolver = _Resolver(core_tbl, merged_dimensions) - pred_expr = _resolve_expr(pred_fn, resolver) - core_tbl = core_tbl.filter(pred_expr) - except Exception: - pass - - # Build aggregation expressions - scope = MeasureScope( - _tbl=core_tbl, - _known=list(merged_base_measures.keys()) + list(merged_calc_measures.keys()), - ) - plan = _build_aggregation_plan( - aggs=self.aggs, - keys=tuple(core_group_keys), - scope=scope, - is_post_agg=False, - merged_base_measures=merged_base_measures, - merged_calc_measures=merged_calc_measures, - tbl=core_tbl, - ) - - # Execute aggregation on core table - if plan.group_by_cols: - gb_exprs = [] - for col in plan.group_by_cols: - if col in core_tbl.columns: - gb_exprs.append(core_tbl[col]) - elif col in merged_dimensions: - dim_fn = merged_dimensions[col] - gb_exprs.append(dim_fn(core_tbl).name(col)) - - agg_exprs = {name: fn(core_tbl) for name, fn in plan.agg_specs.items()} - result = core_tbl.group_by(gb_exprs).aggregate(**agg_exprs) - else: - agg_exprs = {name: fn(core_tbl) for name, fn in plan.agg_specs.items()} - result = core_tbl.aggregate(**agg_exprs) - - # Handle calculated measures - if plan.calc_specs: - from .compile_all import compile_calc_measures - - result = compile_calc_measures(result, plan.calc_specs) - - # --- 3. LEFT JOIN deferred dimension tables --- - for d in deferrable: - dim_tbl = _to_untagged(d.table_op) - - # Reuse the original join predicate to preserve key pairing - # (avoids mismatch from independently sorting left/right keys) - join_preds = d.on_predicate(result, dim_tbl) - - # Compute deferred dimension columns on the dimension table. - # For direct columns (e.g., cc_name) this is a no-op rename. - # For derived expressions (e.g., cc_name.upper()) this materializes - # the computed column so it's available after the join. - dim_cols_to_add = [] - right_dims = d.table_op.get_dimensions() - for dim_name in d.deferred_dims: - short = dim_name.split(".", 1)[1] if "." in dim_name else dim_name - if short in right_dims: - dim_fn = right_dims[short] - if callable(dim_fn): - try: - expr = dim_fn(dim_tbl) - col_name = expr.get_name() - if col_name in dim_tbl.columns: - # Direct column — use as-is - dim_cols_to_add.append((dim_name, col_name)) - else: - # Derived expression — mutate onto dim table - # Use a temp name to avoid collisions - temp_name = f"__deferred_{short}" - dim_tbl = dim_tbl.mutate(**{temp_name: expr}) - dim_cols_to_add.append((dim_name, temp_name)) - except Exception: - pass - - if dim_cols_to_add: - # Perform the LEFT JOIN - joined = result.left_join(dim_tbl, join_preds) - # Select: all from result + deferred dim columns renamed to prefixed names - select_exprs = [result] - for prefixed_name, raw_col in dim_cols_to_add: - select_exprs.append(dim_tbl[raw_col].name(prefixed_name)) - result = joined.select(select_exprs) - - # --- 4. Select only originally requested columns --- - original_cols = list(self.keys) + list(self.aggs.keys()) - # Add deferred dim names (they were requested) - for d in deferrable: - for dd in d.deferred_dims: - if dd not in original_cols: - original_cols.append(dd) - # Also keep calc measure names - if plan.calc_specs: - for cm in plan.calc_specs: - if cm not in original_cols: - original_cols.append(cm) - - available = frozenset(result.columns) - select_cols = [c for c in original_cols if c in available] - if select_cols and set(select_cols) != available: - result = result.select([result[c] for c in select_cols]) - - return result - - # -- helpers for _to_untagged_with_preagg -------------------------------- - - @staticmethod - def _join_preagg_with_dim_bridge( - preagg_results, plan, tbl, group_by_cols, decomposed_means=(), reagg_ops=() - ): - """Join pre-aggregated tables using per-table dimension bridges. - - ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. - """ - from .compile_all import _join_tables - - reagg_map = dict(reagg_ops) - # Include decomposed auxiliary columns in measure names - aux_cols = frozenset(c for _, (sc, cc) in decomposed_means for c in (sc, cc)) - measure_names = ( - frozenset(plan.agg_specs.keys()) | frozenset(plan.calc_specs.keys()) | aux_cols - ) - gb_set = frozenset(group_by_cols) - - def _rejoin_one(pt): - pt_grain = tuple(c for c in pt.columns if c not in measure_names) - pt_meas = tuple(c for c in pt.columns if c in measure_names) - - if gb_set <= frozenset(pt_grain): - # Already has group-by columns — re-aggregate if over-grouped - if frozenset(pt_grain) != gb_set: - re_aggs = {m: _build_reagg(pt[m], reagg_map.get(m, "sum")) for m in pt_meas} - return pt.group_by([pt[c] for c in group_by_cols]).aggregate(**re_aggs) - return pt - - if not pt_grain: - return pt - - # Build a per-table dim bridge with ONLY this table's grain cols - bridge_cols = tuple( - dict.fromkeys(c for c in (*group_by_cols, *pt_grain) if c in tbl.columns) - ) - if not bridge_cols: - return pt - - dim_bridge = tbl.select([tbl[c] for c in bridge_cols]).distinct() - common = tuple(c for c in pt_grain if c in dim_bridge.columns) - if not common: - return pt - - preds = [dim_bridge[c] == pt[c] for c in common] - joined_pt = dim_bridge.left_join(pt, preds).select( - [dim_bridge] + [pt[c] for c in pt_meas] - ) - gb_avail = tuple(c for c in group_by_cols if c in joined_pt.columns) - if gb_avail: - re_aggs = { - m: _build_reagg(joined_pt[m], reagg_map.get(m, "sum")) - for m in pt_meas - if m in joined_pt.columns - } - if re_aggs: - joined_pt = joined_pt.group_by([joined_pt[c] for c in gb_avail]).aggregate( - **re_aggs - ) - return joined_pt - - rejoined = tuple(_rejoin_one(pt) for pt in preagg_results) - result = _join_tables(group_by_cols, list(rejoined)) - - # Recompute MEAN from decomposed SUM/COUNT after combining - for mname, (sc, cc) in decomposed_means: - if sc in result.columns and cc in result.columns: - result = result.mutate(**{mname: result[sc] / result[cc]}) - result = result.drop(sc, cc) - - return result - - @staticmethod - def _build_minimal_dim_bridge( - preagg_results, - plan, - group_by_cols, - join_tree_info, - merged_dimensions, - decomposed_means=(), - reagg_ops=(), - ): - """Build dim bridges from raw tables when full join is unavailable. - - When the full ibis join fails (e.g. column collisions with 3+ - ``join_many`` arms sharing the same key), we build per-table dimension - bridges using only the raw single-table data already captured in - *join_tree_info*. This avoids the ibis collision entirely because we - never join more than two tables at once. - - ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. - """ - from .compile_all import _join_tables - - reagg_map = dict(reagg_ops) - aux_cols = frozenset(c for _, (sc, cc) in decomposed_means for c in (sc, cc)) - measure_names = ( - frozenset(plan.agg_specs.keys()) | frozenset(plan.calc_specs.keys()) | aux_cols - ) - gb_set = frozenset(group_by_cols) - - def _bridge_one_preagg(pt): - pt_grain = tuple(c for c in pt.columns if c not in measure_names) - pt_meas = tuple(c for c in pt.columns if c in measure_names) - - # (a) Pre-agg already carries all group-by cols — re-aggregate. - if gb_set <= frozenset(pt_grain): - if frozenset(pt_grain) != gb_set: - re_aggs = {m: _build_reagg(pt[m], reagg_map.get(m, "sum")) for m in pt_meas} - return pt.group_by([pt[c] for c in group_by_cols]).aggregate(**re_aggs) - return pt - - # (b) Scalar pre-agg — nothing to bridge. - if not pt_grain: - return pt - - # (c) Bridge missing group-by dims from raw tables. - bridged = pt - for gb_col in group_by_cols: - if gb_col in bridged.columns: - continue - bridged = _attach_dim_column( - bridged, - gb_col, - measure_names, - join_tree_info, - merged_dimensions, - ) - - # Re-aggregate onto the requested group-by granularity. - gb_avail = tuple(c for c in group_by_cols if c in bridged.columns) - re_aggs = { - m: _build_reagg(bridged[m], reagg_map.get(m, "sum")) - for m in pt_meas - if m in bridged.columns - } - match (gb_avail, bool(re_aggs)): - case ((), _) | (_, False): - return bridged - case _: - return bridged.group_by([bridged[c] for c in gb_avail]).aggregate(**re_aggs) - - rejoined = tuple(_bridge_one_preagg(pt) for pt in preagg_results) - result = _join_tables(group_by_cols, list(rejoined)) - - # Recompute MEAN from decomposed SUM/COUNT after combining - for mname, (sc, cc) in decomposed_means: - if sc in result.columns and cc in result.columns: - result = result.mutate(**{mname: result[sc] / result[cc]}) - result = result.drop(sc, cc) - - return result - - @staticmethod - def _apply_calc_specs(result, plan, tbl): - """Apply calculated measure specs (ratios, percent-of-total, etc.).""" - from .compile_all import _collect_all_refs, _compile_formula - - needed_totals: set[str] = set() - for ast in plan.calc_specs.values(): - _collect_all_refs(ast, needed_totals) - - if needed_totals: - totals_aggs = {ref: result[ref].sum() for ref in needed_totals if ref in result.columns} - all_tbl = result.aggregate(**totals_aggs) if totals_aggs else None - else: - all_tbl = None - - out = result.cross_join(all_tbl) if all_tbl is not None else result - calc_cols = { - name: _compile_formula(ast, out, all_tbl, tbl if tbl is not None else out) - for name, ast in plan.calc_specs.items() - } - return out.mutate(**calc_cols) - - -class SemanticMutateOp(Relation): - source: Relation - post: dict[ - str, - Callable, - ] # Transformed to FrozenDict[str, _CallableWrapper] in __init__ - nested_columns: tuple[ - str, - ..., - ] = () # Inherited from source if it has nested columns - - def __init__( - self, - source: Relation, - post: dict[str, Callable] | None, - nested_columns: tuple[str, ...] = (), - ) -> None: - frozen_post = FrozenDict( - {name: _ensure_wrapped(fn) for name, fn in (post or {}).items()}, - ) - source_nested = nested_columns if nested_columns else getattr(source, "nested_columns", ()) - - super().__init__( - source=Relation.__coerce__(source), - post=frozen_post, - nested_columns=source_nested, - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values - - @property - def schema(self) -> Schema: - return self.source.schema - - def to_untagged(self): - agg_tbl = _to_untagged(self.source) - - # Process mutations incrementally so each can reference previous ones - # This allows: .mutate(rank=..., is_other=lambda t: t["rank"] > 5) - current_tbl = agg_tbl - for name, fn_wrapped in self.post.items(): - proxy = MeasureScope(_tbl=current_tbl, _known=[], _post_agg=True) - resolved = _resolve_expr(_unwrap(fn_wrapped), proxy) - - new_col = resolved.name(name) - current_tbl = current_tbl.mutate([new_col]) - - return current_tbl - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() - - -class SemanticUnnestOp(Relation): - """Unnest an array column, expanding rows (like Malloy's nested data pattern).""" - - source: Relation - column: str - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def schema(self) -> Schema: - # After unnesting, the schema changes - the array column is replaced by its element schema - # For now, delegate to source schema (ideally we'd update it) - return self.source.schema - - @property - def values(self) -> FrozenDict: - return FrozenDict({}) - - def to_untagged(self): - """Convert to Ibis expression with functional struct unpacking. - - Uses pure helper functions to extract struct fields when unnesting - produces struct columns that need to be expanded. - """ - - def build_struct_fields(col_expr, col_type): - """Pure function: build dict of struct field selections.""" - return {name: col_expr[name] for name in col_type.names} - - def unpack_struct_if_needed(unnested_tbl, column_name): - """Conditionally unpack struct fields into top-level columns.""" - if column_name not in unnested_tbl.columns: - return unnested_tbl - - col_expr = unnested_tbl[column_name] - col_type = col_expr.type() - - # Only Struct types have fields to unpack - if isinstance(col_type, dt.Struct) and col_type.fields: - struct_fields = build_struct_fields(col_expr, col_type) - return unnested_tbl.select(unnested_tbl, **struct_fields) - - return unnested_tbl - - tbl = _to_untagged(self.source) - - if self.column not in tbl.columns: - raise ValueError(f"Column '{self.column}' not found in table") - - try: - unnested = tbl.unnest(self.column) - except Exception as e: - raise ValueError(f"Failed to unnest column '{self.column}': {e}") from e - - return unpack_struct_if_needed(unnested, self.column) - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() - - -class SemanticJoinOp(Relation): - left: Relation - right: Relation - how: str - on: ( - Callable[[Any, Any], Any] | None - ) # Returns BooleanValue from either ibis or xorq.vendor.ibis - cardinality: str # "one", "many", or "cross" - - def __init__( - self, - left: Relation, - right: Relation, - how: str = "left", - on: Callable[[Any, Any], Any] | None = None, - cardinality: str = "one", - ) -> None: - super().__init__( - left=Relation.__coerce__(left), - right=Relation.__coerce__(right), - how=how, - on=on, - cardinality=cardinality, - ) - - def __repr__(self) -> str: - return _semantic_repr(self) - - @property - def values(self) -> FrozenOrderedDict[str, Any]: - vals: dict[str, Any] = {} - vals.update(self.left.values) - vals.update(self.right.values) - return FrozenOrderedDict(vals) - - @property - def schema(self): - """Get schema of semantic table. - - Uses runtime imports to handle both regular ibis and xorq (which vendors ibis). - Converts dtypes to strings to allow Schema to parse them into the correct dtype objects. - """ - fields_dict = {name: str(v.dtype) for name, v in self.values.items()} - return _make_schema(fields_dict) - - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions with metadata.""" - all_roots = _find_all_root_models(self) - return _merge_fields_with_prefixing( - all_roots, - lambda r: _get_field_dict(r, "dimensions"), - source=self, # Pass self to extract join keys - ) - - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of base measures with metadata.""" - all_roots = _find_all_root_models(self) - return _merge_fields_with_prefixing( - all_roots, - lambda r: _get_field_dict(r, "measures"), - source=self, - ) - - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures with metadata.""" - all_roots = _find_all_root_models(self) - return _merge_fields_with_prefixing( - all_roots, - lambda r: _get_field_dict(r, "calc_measures"), - source=self, - ) - - @property - def dimensions(self) -> tuple[str, ...]: - """Get tuple of dimension names.""" - return tuple(self.get_dimensions().keys()) - - @property - def _dims(self) -> dict[str, Dimension]: - return dict(self.get_dimensions()) - - @property - def _base_measures(self) -> dict[str, Measure]: - return dict(self.get_measures()) - - @property - def _calc_measures(self) -> dict[str, Any]: - return dict(self.get_calculated_measures()) - - @property - def calc_measures(self) -> dict[str, Any]: - """Get calculated measures as dict (for consistency with SemanticModel).""" - return dict(self.get_calculated_measures()) - - @property - def measures(self) -> tuple[str, ...]: - return tuple(self.get_measures().keys()) + tuple( - self.get_calculated_measures().keys(), - ) - - @property - def json_definition(self) -> Mapping[str, Any]: - return _build_json_definition(self.get_dimensions(), self.get_measures(), None) - - @property - def name(self) -> str | None: - return None - - @property - def description(self) -> str | None: - """Get description for joined model by combining root model descriptions.""" - roots = _find_all_root_models(self) - base_descriptions = [] - for root in roots: - root_name = getattr(root, "name", None) or "unnamed" - root_desc = getattr(root, "description", None) - if root_desc: - base_descriptions.append(f"{root_name} ({root_desc})") - else: - base_descriptions.append(root_name) - if base_descriptions: - return "Joined model combining: " + ", ".join(base_descriptions) - return None - - @property - def table(self): - return self.to_untagged() - - def query( - self, - dimensions: Sequence[str] | None = None, - measures: Sequence[str] | None = None, - filters: list | None = None, - order_by: Sequence[tuple[str, str]] | None = None, - limit: int | None = None, - time_grain: str | None = None, - time_range: dict[str, str] | None = None, - having: list | None = None, - ): - from .query import query as build_query - - return build_query( - semantic_table=self, - dimensions=dimensions, - measures=measures, - filters=filters, - order_by=order_by, - limit=limit, - time_grain=time_grain, - time_range=time_range, - having=having, - ) - - def with_dimensions(self, **dims) -> SemanticTable: - return _semantic_table( - table=self.to_untagged(), - dimensions={**self.get_dimensions(), **dims}, - measures=self.get_measures(), - calc_measures=self.get_calculated_measures(), - name=None, - ) - - def with_measures(self, **meas) -> SemanticTable: - joined_tbl = self.to_untagged() - all_known = ( - list(self.get_measures().keys()) - + list(self.get_calculated_measures().keys()) - + list(meas.keys()) - ) - scope = MeasureScope(_tbl=joined_tbl, _known=all_known) - - new_base, new_calc = ( - dict(self.get_measures()), - dict(self.get_calculated_measures()), - ) - for name, fn_or_expr in meas.items(): - kind, value = _classify_measure(fn_or_expr, scope, name) - (new_calc if kind == "calc" else new_base)[name] = value - - return _semantic_table( - table=joined_tbl, - dimensions=self.get_dimensions(), - measures=new_base, - calc_measures=new_calc, - name=None, - _source_join=self, # Pass join reference for projection pushdown - ) - - def group_by(self, *keys: str) -> SemanticGroupBy: - from .expr import SemanticGroupBy - - return SemanticGroupBy(source=self, keys=keys) - - def filter(self, predicate: Callable) -> SemanticFilter: - from .expr import SemanticFilter - - return SemanticFilter(source=self, predicate=predicate) - - def join_one( - self, - other: SemanticTable, - on: Callable[[Any, Any], ir.BooleanValue], - how: str = "left", - ): - """Join with one-to-one relationship semantics (left outer join).""" - from .expr import SemanticJoin - - return SemanticJoin( - left=self, - right=other.op(), - on=on, - how=how, - cardinality="one", - ) - - def join_many( - self, - other: SemanticTable, - on: Callable[[Any, Any], ir.BooleanValue], - how: str = "left", - ): - """Join with one-to-many relationship semantics.""" - from .expr import SemanticJoin - - return SemanticJoin( - left=self, - right=other.op(), - on=on, - how=how, - cardinality="many", - ) - - def join_cross(self, other: SemanticTable): - """Cross join (Cartesian product) with another semantic model.""" - from .expr import SemanticJoin - - return SemanticJoin( - left=self, - right=other.op(), - on=None, - how="cross", - cardinality="cross", - ) - - def join(self, *args, **kwargs): - """Deprecated: Use join_one(), join_many(), or join_cross() instead.""" - raise TypeError(_JOIN_REMOVED_MESSAGE) - - def index( - self, - selector: str | list[str] | Callable | None = None, - by: str | None = None, - sample: int | None = None, - ) -> SemanticIndexOp: - """Create an index for search/discovery. - - Supports ibis selectors (s.all(), s.cols(), etc.). - """ - - # Handle ibis selectors - processed_selector = selector - if selector is not None and "ibis.selectors" in str(type(selector).__module__): - # Handle s.all() - select all columns - if type(selector).__name__ == "AllColumns": - processed_selector = None - # Handle s.cols() - select specific columns - elif type(selector).__name__ == "Cols": - # Extract column names from the Cols selector - processed_selector = sorted(selector.names) - # For other selectors, keep as-is - else: - processed_selector = selector - - return SemanticIndexOp(source=self, selector=processed_selector, by=by, sample=sample) - - def _collect_leaf_table_names(self) -> set[str]: - """Collect names of all leaf (base) tables in this join tree.""" - tables = set() - - if isinstance(self.left, SemanticJoinOp): - tables |= self.left._collect_leaf_table_names() - else: - left_name = getattr(self.left, "name", None) - if left_name: - tables.add(left_name) - - if isinstance(self.right, SemanticJoinOp): - tables |= self.right._collect_leaf_table_names() - else: - right_name = getattr(self.right, "name", None) - if right_name: - tables.add(right_name) - - return tables - - @property - def required_columns(self): - """ - Column requirements for projection pushdown. - - This property makes projection pushdown intrinsic to the join operation, - similar to how `schema` is intrinsic to a relation. - - Computes what columns are needed from each leaf table based on: - 1. Columns needed for measures defined on the joined tables - 2. Join key columns - - Returns: - Dict mapping table names to sets of required column names. - """ - return self._compute_required_columns() - - def _compute_required_columns(self, parent_requirements: dict[str, set[str]] | None = None): - """ - Compute column requirements for projection pushdown. - - Args: - parent_requirements: Optional dict of specific columns requested by parent operations. - - Returns: - Dict mapping table names to sets of required column names. - """ - # Start with parent requirements using immutable TableRequirements - requirements = projection_utils.TableRequirements.from_dict( - parent_requirements if parent_requirements else {} - ) - - # Get all root models in join tree - all_roots = _find_all_root_models(self) - - # Collect leaf tables - def collect_leaf_tables(node): - if isinstance(node, SemanticJoinOp): - return collect_leaf_tables(node.left) + collect_leaf_tables(node.right) - table_name = getattr(node, "name", None) - return [(table_name, _to_untagged(node))] if table_name else [] - - leaf_tables = collect_leaf_tables(self) - - # Group measures by table - measures_by_table = {} - for root in all_roots: - root_measures = _get_field_dict(root, "measures") - if root.name: - measures_by_table[root.name] = root_measures - else: - # Parse prefixed measures (e.g., "marketing.spend") - for measure_name, measure_obj in root_measures.items(): - if "." in measure_name: - table_name = measure_name.split(".", 1)[0] - if table_name not in measures_by_table: - measures_by_table[table_name] = {} - measures_by_table[table_name][measure_name] = measure_obj - - # Extract columns needed by measures (using immutable operations) - for table_name, table_ibis in leaf_tables: - if table_name in measures_by_table: - for measure_obj in measures_by_table[table_name].values(): - # All measure_obj values are Measure instances with expr attribute - measure_fn = measure_obj.expr - if callable(measure_fn): - cols = projection_utils.extract_columns_from_callable_safe( - measure_fn, table_ibis - ) - if cols: - requirements = requirements.add_columns(table_name, cols) - - # Extract and add join key columns - if self.on is not None: - # Get full schema for join key extraction - temp_left = ( - self.left.to_untagged() - if isinstance(self.left, SemanticJoinOp) - else _to_untagged(self.left) - ) - temp_right = ( - self.right.to_untagged() - if isinstance(self.right, SemanticJoinOp) - else _to_untagged(self.right) - ) - - join_keys_result = _extract_join_key_columns(self.on, temp_left, temp_right) - - if join_keys_result.is_success(): - # Add join keys to leaf tables (immutable operations) - if isinstance(self.left, SemanticJoinOp): - for col in join_keys_result.left_columns: - for leaf_name in self.left._collect_leaf_table_names(): - leaf_table = self._get_leaf_table_by_name(self.left, leaf_name) - if leaf_table and col in _to_untagged(leaf_table).columns: - requirements = requirements.add_columns(leaf_name, {col}) - else: - left_name = getattr(self.left, "name", None) - if left_name: - requirements = requirements.add_columns( - left_name, join_keys_result.left_columns - ) - - if isinstance(self.right, SemanticJoinOp): - for col in join_keys_result.right_columns: - for leaf_name in self.right._collect_leaf_table_names(): - leaf_table = self._get_leaf_table_by_name(self.right, leaf_name) - if leaf_table and col in _to_untagged(leaf_table).columns: - requirements = requirements.add_columns(leaf_name, {col}) - else: - right_name = getattr(self.right, "name", None) - if right_name: - requirements = requirements.add_columns( - right_name, join_keys_result.right_columns - ) - - return requirements.to_dict() - - def _get_leaf_table_by_name(self, join_op: SemanticJoinOp, target_name: str): - """Find a leaf table by name in a join tree.""" - if isinstance(join_op.left, SemanticJoinOp): - result = self._get_leaf_table_by_name(join_op.left, target_name) - if result is not None: - return result - else: - left_name = getattr(join_op.left, "name", None) - if left_name == target_name: - return join_op.left - - if isinstance(join_op.right, SemanticJoinOp): - result = self._get_leaf_table_by_name(join_op.right, target_name) - if result is not None: - return result - else: - right_name = getattr(join_op.right, "name", None) - if right_name == target_name: - return join_op.right - - return None - - def _collect_join_keys_for_leaves(self) -> dict[str, set[str]]: - """Collect join keys needed by each leaf table. - - For nested joins, we trace join keys back to their source leaf tables. - Returns dict mapping leaf table names to sets of columns needed for joins. - """ - join_columns: dict[str, set[str]] = {} - - # Recursively collect from nested joins - if isinstance(self.left, SemanticJoinOp): - nested_keys = self.left._collect_join_keys_for_leaves() - for table_name, cols in nested_keys.items(): - existing = join_columns.get(table_name, set()) - join_columns[table_name] = existing | cols - - if isinstance(self.right, SemanticJoinOp): - nested_keys = self.right._collect_join_keys_for_leaves() - for table_name, cols in nested_keys.items(): - existing = join_columns.get(table_name, set()) - join_columns[table_name] = existing | cols - - # Add join keys for THIS level - if self.on is not None: - # Convert without projection to get full schema - temp_left = ( - self.left.to_untagged(parent_requirements=None) - if isinstance(self.left, SemanticJoinOp) - else _to_untagged(self.left) - ) - temp_right = ( - self.right.to_untagged(parent_requirements=None) - if isinstance(self.right, SemanticJoinOp) - else _to_untagged(self.right) - ) - - join_keys = _extract_join_key_columns(self.on, temp_left, temp_right) - - if join_keys.is_success(): - # Add join keys to the appropriate leaf tables - if not isinstance(self.left, SemanticJoinOp): - # Left is a leaf table - left_name = getattr(self.left, "name", None) - if left_name: - existing = join_columns.get(left_name, set()) - join_columns[left_name] = existing | join_keys.left_columns - else: - # Left is a nested join - need to map columns back to source tables - # Get all leaf tables from the nested join and their schemas - left_leaves = self.left._collect_leaf_table_names() - for col in join_keys.left_columns: - # Add column to each leaf table that actually has this column - for table_name in left_leaves: - if table_name: - # Check if this table actually has this column - # We do this by converting the table and checking its schema - leaf_table = self._get_leaf_table_by_name(self.left, table_name) - if leaf_table is not None: - leaf_ibis = _to_untagged(leaf_table) - if col in leaf_ibis.columns: - existing = join_columns.get(table_name, set()) - join_columns[table_name] = existing | {col} - - if not isinstance(self.right, SemanticJoinOp): - # Right is a leaf table - right_name = getattr(self.right, "name", None) - if right_name: - existing = join_columns.get(right_name, set()) - join_columns[right_name] = existing | join_keys.right_columns - else: - # Right is a nested join - right_leaves = self.right._collect_leaf_table_names() - for col in join_keys.right_columns: - for table_name in right_leaves: - if table_name: - leaf_table = self._get_leaf_table_by_name(self.right, table_name) - if leaf_table is not None: - leaf_ibis = _to_untagged(leaf_table) - if col in leaf_ibis.columns: - existing = join_columns.get(table_name, set()) - join_columns[table_name] = existing | {col} - - return join_columns - - def _augment_parent_requirements_for_pruning( - self, - parent_requirements: dict[str, set[str]] | None, - ) -> dict[str, set[str]] | None: - """Expand table requirements with bridge tables needed by descendant joins.""" - if parent_requirements is None: - return None - - augmented = {table: set(cols) for table, cols in parent_requirements.items()} - - right_tables = ( - self.right._collect_leaf_table_names() - if isinstance(self.right, SemanticJoinOp) - else {getattr(self.right, "name", None)} - {None} - ) - if augmented.keys() & right_tables and self.on is not None: - temp_left = ( - self.left.to_untagged(parent_requirements=None) - if isinstance(self.left, SemanticJoinOp) - else _to_untagged(self.left) - ) - temp_right = ( - self.right.to_untagged(parent_requirements=None) - if isinstance(self.right, SemanticJoinOp) - else _to_untagged(self.right) - ) - join_keys = _extract_join_key_columns(self.on, temp_left, temp_right) - if join_keys.is_success(): - join_key_columns = join_keys.left_columns | join_keys.right_columns - if isinstance(self.left, SemanticJoinOp): - for col in join_keys.left_columns: - for leaf_name in self.left._collect_leaf_table_names(): - leaf_table = self._get_leaf_table_by_name(self.left, leaf_name) - if leaf_table is not None and col in _to_untagged(leaf_table).columns: - augmented.setdefault(leaf_name, set()) - else: - left_name = getattr(self.left, "name", None) - if left_name: - augmented.setdefault(left_name, set()) + # Push applicable filters to per-table raw tables + if filters: + all_pushed = True + for pred in filters: + pred_fn = _unwrap(pred) + try: + pred_expr = pred_fn(raw_tbl) + raw_tbl = raw_tbl.filter(pred_expr) + except Exception: + all_pushed = False # filter references columns not on this table - right_needed_tables = right_tables & augmented.keys() - if isinstance(self.left, SemanticJoinOp) and right_needed_tables: - left_leaf_names = self.left._collect_leaf_table_names() - right_needed_leaf_names = [ - leaf_name for leaf_name in right_needed_tables - if self._get_leaf_table_by_name(self, leaf_name) is not None - ] - for left_leaf_name in left_leaf_names: - left_leaf = self._get_leaf_table_by_name(self.left, left_leaf_name) - if left_leaf is None: - continue - left_columns = set(_to_untagged(left_leaf).columns) - join_key_columns - if not left_columns: - continue - for right_leaf_name in right_needed_leaf_names: - right_leaf = self._get_leaf_table_by_name(self, right_leaf_name) - if right_leaf is None: + # If some filters couldn't be pushed (cross-table filters), + # restrict via join keys from the filtered full joined table + # or from filtered raw tables that own the filter columns. + if not all_pushed: + jk = join_tree_info.table_join_keys.get(table_name, set()) + if tbl is not None: + shared = sorted(jk & set(raw_tbl.columns) & set(tbl.columns)) + if shared: + key_bridge = tbl.select([tbl[c] for c in shared]).distinct() + preds = [raw_tbl[c] == key_bridge[c] for c in shared] + raw_tbl = raw_tbl.inner_join(key_bridge, preds).select(raw_tbl) + else: + # Chasm fallback: build key bridge from other raw tables + for other_name, other_op in join_tree_info.table_ops.items(): + if other_name == table_name: continue - right_columns = set(_to_untagged(right_leaf).columns) - join_key_columns - if left_columns & right_columns: - augmented.setdefault(left_leaf_name, set()) - break - - if isinstance(self.left, SemanticJoinOp): - augmented = self.left._augment_parent_requirements_for_pruning(augmented) or augmented - if isinstance(self.right, SemanticJoinOp): - augmented = self.right._augment_parent_requirements_for_pruning(augmented) or augmented - - return augmented - - def _should_prune_right(self, parent_requirements: dict[str, set[str]] | None) -> bool: - """Return whether the right side can be skipped for the given requirements.""" - if ( - parent_requirements is None - or self.cardinality != "one" - or self.how != "left" - ): - return False - - needed_tables = frozenset(parent_requirements.keys()) - right_tables = ( - self.right._collect_leaf_table_names() - if isinstance(self.right, SemanticJoinOp) - else {getattr(self.right, "name", None)} - {None} - ) - if not right_tables or right_tables & needed_tables: - return False - - # If an earlier sibling shares non-join columns with a still-needed - # table, pruning it changes the later table's rname-based aliases - # (e.g. ``state_right2`` vs ``state_right``). Keep the sibling so - # wrapped dimensions keep resolving to the expected columns. - remaining_tables = needed_tables - right_tables - if remaining_tables: - temp_left = ( - self.left.to_untagged(parent_requirements=None) - if isinstance(self.left, SemanticJoinOp) - else _to_untagged(self.left) - ) - temp_right = ( - self.right.to_untagged(parent_requirements=None) - if isinstance(self.right, SemanticJoinOp) - else _to_untagged(self.right) - ) - join_keys = _extract_join_key_columns(self.on, temp_left, temp_right) - join_key_columns = ( - (join_keys.left_columns | join_keys.right_columns) - if self.on is not None and join_keys.is_success() - else set() - ) - - for right_name in right_tables: - right_leaf = self._get_leaf_table_by_name(self, right_name) - if right_leaf is None: - continue - right_columns = set(_to_untagged(right_leaf).columns) - join_key_columns - if not right_columns: - continue - for needed_name in remaining_tables: - needed_leaf = self._get_leaf_table_by_name(self, needed_name) - if needed_leaf is None: - continue - needed_columns = set(_to_untagged(needed_leaf).columns) - join_key_columns - if right_columns & needed_columns: - return False - - return True - - def _effective_join_depth(self, parent_requirements: dict[str, set[str]] | None) -> int: - """Count the surviving left-spine joins after pruning.""" - augmented = self._augment_parent_requirements_for_pruning(parent_requirements) - left_depth = ( - self.left._effective_join_depth(augmented) - if isinstance(self.left, SemanticJoinOp) - else 0 - ) - if self._should_prune_right(augmented): - return left_depth - return left_depth + 1 - - @staticmethod - def _join_depth(op) -> int: - """Count nested left SemanticJoinOps to determine join depth.""" - depth = 0 - current = op - while isinstance(current, SemanticJoinOp): - depth += 1 - current = current.left - return depth - - @staticmethod - def _rname_for_depth(depth: int) -> str: - """Return the ``rname`` template for the given join depth. - - ibis uses ``{name}_right`` by default. When three or more tables - share a column name the second ``_right`` collides with the first. - We avoid this by appending the depth: ``_right``, ``_right2``, - ``_right3``, … - """ - return "{name}_right" if depth <= 1 else f"{{name}}_right{depth}" - - def to_untagged(self, parent_requirements: dict[str, set[str]] | None = None): - """Convert join to Ibis expression, pruning unnecessary joins. - - When *parent_requirements* is provided (a dict mapping table names to - sets of needed columns), right-side leaf tables that contribute no - required columns are skipped entirely. This avoids expensive joins - to dimension tables whose columns are never referenced by the query. - - Args: - parent_requirements: Optional mapping of ``{table_name: {col, …}}``. - When provided, joins to tables absent from this dict are elided. - - Returns: - Ibis join expression (potentially simplified). - """ - from .convert import _Resolver - - augmented_requirements = self._augment_parent_requirements_for_pruning(parent_requirements) - - # --- Join pruning: skip right-side tables not needed by the query --- - # Only prune join_one with LEFT join semantics. Inner joins act as - # a filter (excluding unmatched left rows) and must not be removed. - # join_many / join_cross affect row counts and are never pruned here - # (join_many is intercepted by the pre-aggregation path earlier). - if self._should_prune_right(augmented_requirements): - # Right side contributes nothing the parent needs — skip it. - if isinstance(self.left, SemanticJoinOp): - return self.left.to_untagged(parent_requirements=augmented_requirements) - return _to_untagged(self.left) + other_raw = _to_untagged(other_op) + can_filter = False + for pred in filters: + pred_fn = _unwrap(pred) + try: + pred_expr = pred_fn(other_raw) + other_raw = other_raw.filter(pred_expr) + can_filter = True + except Exception: + pass + if can_filter: + other_jk = join_tree_info.table_join_keys.get(other_name, set()) + shared = sorted( + jk & other_jk & set(raw_tbl.columns) & set(other_raw.columns) + ) + if shared: + key_bridge = other_raw.select( + [other_raw[c] for c in shared] + ).distinct() + preds = [raw_tbl[c] == key_bridge[c] for c in shared] + raw_tbl = raw_tbl.inner_join(key_bridge, preds).select(raw_tbl) + break - # Build both sides, passing requirements down for recursive pruning. - left_tbl = ( - _to_untagged(self.left) - if not isinstance(self.left, SemanticJoinOp) - else self.left.to_untagged(parent_requirements=augmented_requirements) - ) - right_tbl = ( - _to_untagged(self.right) - if not isinstance(self.right, SemanticJoinOp) - else self.right.to_untagged(parent_requirements=augmented_requirements) - ) + # Apply mutated group-by keys to this raw table so they can + # participate in the per-table grain computation. + if mutated_gb_keys: + for mutate_post in mutates: + for name, fn_wrapped in mutate_post.items(): + if name in mutated_gb_keys and name not in raw_tbl.columns: + try: + fn = _unwrap(fn_wrapped) + resolved = _resolve_expr(fn, raw_tbl) + raw_tbl = raw_tbl.mutate(**{name: resolved}) + except (TypeError, KeyError, AttributeError, + ibis.common.exceptions.IbisTypeError): + pass - # Rebind right side's DatabaseTable ops to use the same backend as - # the left side. from_ibis() creates a separate Backend object per - # call; xorq >=0.3.11 raises "Multiple backends found" unless all - # tables in a join share the same backend instance. - left_tbl, right_tbl = self._rebind_join_backends(left_tbl, right_tbl) + table_measures = _get_field_dict(table_op, "measures") + table_dims = _get_field_dict(table_op, "dimensions") + # Captured after mutated-key application above so the grain + # computation can see the newly-added columns. + raw_columns = set(raw_tbl.columns) - depth = ( - self._effective_join_depth(augmented_requirements) - if augmented_requirements is not None - else self._join_depth(self) - ) - rname = self._rname_for_depth(depth) + # Build agg expressions on the raw table + agg_exprs: dict = {} + for mname, _mfn in measures.items(): + short = mname.split(".", 1)[1] if "." in mname else mname + if short in table_measures: + expr = table_measures[short](raw_tbl) + # Decompose MEAN into SUM + COUNT for correct re-aggregation + if _is_mean_expr(expr): + mean_op = expr.op() + base_col = mean_op.arg.to_expr() + sum_col = f"_sum__{mname}" + count_col = f"_count__{mname}" + agg_exprs[sum_col] = base_col.sum() + agg_exprs[count_col] = base_col.count() + _decomposed_means[mname] = (sum_col, count_col) + elif _is_count_distinct_expr(expr): + # COUNT DISTINCT is immune to fan-out — defer past pre-agg + _deferred_count_distincts[mname] = ( + table_name, short, raw_tbl, table_measures[short], + ) + else: + _reagg_ops[mname] = _reagg_op_for_expr(expr) + agg_exprs[mname] = expr - if self.on is None: - return left_tbl.join(right_tbl, how=self.how, rname=rname) + if not agg_exprs: + continue - # Detect column name conflicts that cause ibis/xorq to raise - # "Ambiguous field reference" during predicate resolution. - conflicting = frozenset(left_tbl.columns) & frozenset(right_tbl.columns) + # --- Compute grain --- + if not group_by_cols: + # No group-by → scalar aggregate + pt = raw_tbl.aggregate(**agg_exprs) + # Recompute MEAN from SUM/COUNT for scalar results + for mname, (sc, cc) in _decomposed_means.items(): + if sc in pt.columns and cc in pt.columns: + pt = pt.mutate(**{mname: pt[sc] / pt[cc]}) + pt = pt.drop(sc, cc) + _preagg_results.append(pt) + continue - if not conflicting: - pred = self.on(_Resolver(left_tbl), _Resolver(right_tbl)) - return left_tbl.join(right_tbl, pred, how=self.how, rname=rname) + # a) group-by dims that live on this table + _local_dims = [] + has_cross_table_gb = False + for gb_key in group_by_cols: + # Mutated columns are already on raw_tbl (applied above) + if gb_key in mutated_gb_keys: + if gb_key in raw_columns and gb_key not in _local_dims: + _local_dims.append(gb_key) + elif "." in gb_key: + prefix, short = gb_key.split(".", 1) + if prefix == table_name and short in table_dims: + dim_fn = table_dims[short] + if callable(dim_fn): + dim_expr = dim_fn(raw_tbl) + col_name = dim_expr.get_name() + if col_name == short and col_name in raw_columns: + # Simple column reference — use directly + if col_name not in _local_dims: + _local_dims.append(col_name) + elif col_name in raw_columns or short not in raw_columns: + # Derived dimension — materialize on raw_tbl + raw_tbl = raw_tbl.mutate(**{short: dim_expr}) + raw_columns = set(raw_tbl.columns) + if short not in _local_dims: + _local_dims.append(short) + elif prefix != table_name: + has_cross_table_gb = True - # Temporarily rename conflicting left columns so the predicate - # can be resolved without ambiguity. - # ibis rename convention: {new_name: old_name} - rename_left = {f"{_BSL_JOIN_KEY_TMP_PREFIX}{c}": c for c in conflicting} - left_safe = left_tbl.rename(rename_left) + join_keys = join_tree_info.table_join_keys.get(table_name, set()) + available_jk = tuple(jk for jk in sorted(join_keys) if jk in raw_columns) - # Resolver that transparently maps original names → temp names, - # so predicates like ``lambda f, a: f.tail_num == a.tail_num`` - # still work even though left's ``tail_num`` was renamed. - orig_to_tmp = {c: f"{_BSL_JOIN_KEY_TMP_PREFIX}{c}" for c in conflicting} + # b) if none found, use join keys; if cross-table gb, augment with them + match (_local_dims, has_cross_table_gb): + case ([], _): + grain = available_jk + case (_, True): + grain = tuple( + dict.fromkeys( + _local_dims + [jk for jk in available_jk if jk not in _local_dims] + ) + ) + case _: + grain = tuple(_local_dims) - pred = self.on( - _RenamedResolver(left_safe, orig_to_tmp), - _Resolver(right_tbl), - ) - joined = left_safe.join(right_tbl, pred, how=self.how, rname=rname) + if grain: + _preagg_results.append( + raw_tbl.group_by([raw_tbl[c] for c in grain]).aggregate(**agg_exprs) + ) + else: + _preagg_results.append(raw_tbl.aggregate(**agg_exprs)) - # Restore final column names (ibis convention: {new: old}): - # - left temp columns → original names - # - right conflicting columns → depth-based rname suffix - rename_final = {c: f"{_BSL_JOIN_KEY_TMP_PREFIX}{c}" for c in conflicting} | { - rname.replace("{name}", c): c for c in conflicting - } + # Freeze mutable accumulators + preagg_results = tuple(_preagg_results) + decomposed_means = tuple(_decomposed_means.items()) + reagg_ops = tuple(_reagg_ops.items()) - return joined.rename(rename_final) + if not preagg_results and not _deferred_count_distincts: + if tbl is not None: + return tbl.aggregate({n: f(tbl) for n, f in plan.agg_specs.items()}) + raise ValueError("No aggregation results and full join unavailable") - @staticmethod - def _rebind_join_backends(left_tbl, right_tbl): - """Rebind DatabaseTable ops so both sides share a single backend. + # --- 5. Combine pre-agg results --- + result = None + if preagg_results: + if not group_by_cols: + # Cross-join all scalar results + result = preagg_results[0] + for pt in preagg_results[1:]: + result = result.cross_join(pt) + elif tbl is not None: + result = self._join_preagg_with_dim_bridge( + preagg_results, + plan, + tbl, + group_by_cols, + decomposed_means=decomposed_means, + reagg_ops=reagg_ops, + ) + else: + # Chasm fallback with group-by: build minimal dim bridge from raw tables + result = self._build_minimal_dim_bridge( + preagg_results, + plan, + group_by_cols, + join_tree_info, + merged_dimensions, + decomposed_means=decomposed_means, + reagg_ops=reagg_ops, + ) - When tables are individually wrapped via ``from_ibis()``, each gets - a distinct ``Backend`` object. xorq >=0.3.11 raises "Multiple - backends found" unless all tables in a join share the same instance. - Uses ``op.replace()`` (ibis graph rewriting) to swap out the - ``source`` field on every ``DatabaseTable`` node in the right tree. + # --- 5b. Compute deferred COUNT DISTINCT measures --- + # Optimisation: compute on the raw source table when all group-by + # columns are local (avoids scanning the fanned-out joined table). + # Fall back to the full joined table only for cross-table group-bys. + if _deferred_count_distincts: + cd_parts: list = [] + for mname, (_src_tbl_name, _short, src_raw, src_fn) in ( + _deferred_count_distincts.items() + ): + src_cols = set(src_raw.columns) + # Determine which group-by columns live on the raw source table + local_gb = [c for c in group_by_cols if c in src_cols] + if not group_by_cols or frozenset(local_gb) == frozenset(group_by_cols): + # All group-by cols are local (or scalar) — compute on raw table + cd_expr = src_fn(src_raw) + if local_gb: + cd_pt = src_raw.group_by( + [src_raw[c] for c in local_gb] + ).aggregate(**{mname: cd_expr}) + else: + cd_pt = src_raw.aggregate(**{mname: cd_expr}) + cd_parts.append(cd_pt) + else: + # Cross-table group-by — need the full joined table + if tbl is None: + raise ValueError( + "COUNT DISTINCT measures require the full joined table " + "for cross-table group-by but it is unavailable (chasm " + "fallback). Use join_one for reference tables or remove " + "count-distinct measures from this query." + ) + cd_expr = plan.agg_specs[mname](tbl) + gb_available = [c for c in group_by_cols if c in tbl.columns] + if gb_available: + cd_pt = tbl.group_by( + [tbl[c] for c in gb_available] + ).aggregate(**{mname: cd_expr}) + else: + cd_pt = tbl.aggregate(**{mname: cd_expr}) + cd_parts.append(cd_pt) - For plain ibis expressions (backends xorq doesn't wrap, e.g. - BigQuery), ``walk_nodes`` can't traverse the tree — fall back to - returning the inputs unchanged so ibis executes the join natively. - """ - try: - from ._xorq import relations as xorq_rel, walk_nodes - except ImportError: - return left_tbl, right_tbl + # Merge count-distinct parts into result + for cd_pt in cd_parts: + cd_meas = [c for c in cd_pt.columns if c in _deferred_count_distincts] + cd_grain = [c for c in cd_pt.columns if c not in _deferred_count_distincts] + if result is None: + result = cd_pt + elif cd_grain: + common = [c for c in cd_grain if c in result.columns] + if common: + preds = [result[c] == cd_pt[c] for c in common] + result = result.left_join(cd_pt, preds).select( + [result] + [cd_pt[m] for m in cd_meas] + ) + else: + result = result.cross_join(cd_pt) + else: + result = result.cross_join(cd_pt) - # Find a canonical backend from the left tree. - try: - db_tables = list(walk_nodes((xorq_rel.DatabaseTable,), left_tbl)) - except Exception: - # Plain ibis tree — no DatabaseTable nodes for xorq to walk. - logger.debug("walk_nodes failed on plain ibis left side", exc_info=True) - return left_tbl, right_tbl - canonical = db_tables[0].source if db_tables else None + # --- 6. Apply calc_specs --- + if plan.calc_specs: + result = self._apply_calc_specs(result, plan, tbl) - if canonical is None: - return left_tbl, right_tbl + # --- 7. Select requested columns --- + available = frozenset(result.columns) + select_cols = tuple( + dict.fromkeys( + c + for c in (*plan.group_by_cols, *plan.requested_measures, *plan.calc_specs.keys()) + if c in available + ) + ) + if select_cols: + result = result.select([result[c] for c in select_cols]) - def _recreate(op, _kwargs, **overrides): - kwargs = dict(zip(op.__argnames__, op.__args__, strict=False)) - if _kwargs: - kwargs.update(_kwargs) - kwargs.update(overrides) - return op.__recreate__(kwargs) + return result - def replacer(op, _kwargs): - if isinstance(op, xorq_rel.DatabaseTable) and op.source is not canonical: - return _recreate(op, _kwargs, source=canonical) - # Propagate rewritten children (e.g. SelfReference wrapping - # a replaced DatabaseTable). - if _kwargs: - return _recreate(op, _kwargs) - return op + def _to_untagged_with_deferred_joins( + self, + all_roots: list, + join_op: "SemanticJoinOp", + join_tree_info: "_JoinTreeInfo", + deferrable: list["_DeferrableJoin"], + filters: list | None = None, + ): + """Aggregate first, then LEFT JOIN deferred dimension tables. - new_left = left_tbl.op().replace(replacer).to_expr() - new_right = right_tbl.op().replace(replacer).to_expr() - return new_left, new_right + For ``join_one`` dimension lookups where the right table's PK matches + the join key and no measures from the right table are used, we can: + 1. Strip deferred tables from the join tree + 2. Add join keys to GROUP BY so they survive aggregation + 3. Aggregate on the core (non-deferred) tables + 4. LEFT JOIN deferred dimension tables onto the aggregated result + """ + filters = filters or [] + deferred_names = {d.table_name for d in deferrable} - def execute(self): - return _rebind_to_canonical_backend(self.to_untagged()).execute() + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + merged_base_measures = _get_merged_fields(all_roots, "measures") + merged_calc_measures = _get_merged_fields(all_roots, "calc_measures") - def compile(self, **kwargs): - return _rebind_to_canonical_backend(self.to_untagged()).compile(**kwargs) + # --- 1. Rebuild join tree without deferred tables --- + def strip_deferred(node): + """Remove deferred right-side tables from the join tree.""" + if not isinstance(node, SemanticJoinOp): + return node + # Recurse left + new_left = strip_deferred(node.left) + # If right side is a deferred table, return just the left + if ( + isinstance(node.right, SemanticTableOp) + and node.right.name in deferred_names + ): + return new_left + return SemanticJoinOp( + left=new_left, + right=node.right, + how=node.how, + on=node.on, + cardinality=node.cardinality, + ) - def sql(self, **kwargs): - return ibis.to_sql(_rebind_to_canonical_backend(self.to_untagged()), **kwargs) + core_join = strip_deferred(join_op) - def __getitem__(self, key): - dims_dict = self.get_dimensions() - if key in dims_dict: - return dims_dict[key] + # --- 2. Build core table and aggregate --- + if isinstance(core_join, SemanticJoinOp): + core_tbl = core_join.to_untagged(parent_requirements=None) + else: + core_tbl = _to_untagged(core_join) - meas_dict = self.get_measures() - if key in meas_dict: - return meas_dict[key] + # Resolve group-by dimensions on the core table + core_group_keys = [] + for k in self.keys: + if k in deferred_names or any(k.startswith(f"{dn}.") for dn in deferred_names): + # Skip deferred table's own dimensions from group-by + # (they'll be added via post-agg join) + continue + core_group_keys.append(k) - calc_meas_dict = self.get_calculated_measures() - if key in calc_meas_dict: - return calc_meas_dict[key] + # Add left-side join keys to group-by so they survive aggregation + for d in deferrable: + for jk in d.join_keys_left: + if jk not in core_group_keys and jk in core_tbl.columns: + core_group_keys.append(jk) - raise KeyError( - f"'{key}' not found in dimensions, measures, or calculated measures", + # Resolve dimensions for group-by on the core table + core_tbl = _mutate_dimensions_with_dependencies( + core_tbl, + [k for k in core_group_keys if k in merged_dimensions], + merged_dimensions, ) - def pipe(self, func, *args, **kwargs): - return func(self, *args, **kwargs) + # Apply filters + if filters: + from ..convert import _Resolver - def as_table(self) -> SemanticTable: - """Convert to SemanticTable, preserving merged metadata from both sides.""" - return _semantic_table( - table=self.to_untagged(), - dimensions=self.get_dimensions(), - measures=self.get_measures(), - calc_measures=self.get_calculated_measures(), + for pred in filters: + pred_fn = _unwrap(pred) + try: + resolver = _Resolver(core_tbl, merged_dimensions) + pred_expr = _resolve_expr(pred_fn, resolver) + core_tbl = core_tbl.filter(pred_expr) + except Exception: + pass + + # Build aggregation expressions + scope = MeasureScope( + _tbl=core_tbl, + _known=list(merged_base_measures.keys()) + list(merged_calc_measures.keys()), + ) + plan = _build_aggregation_plan( + aggs=self.aggs, + keys=tuple(core_group_keys), + scope=scope, + is_post_agg=False, + merged_base_measures=merged_base_measures, + merged_calc_measures=merged_calc_measures, + tbl=core_tbl, ) + # Execute aggregation on core table + if plan.group_by_cols: + gb_exprs = [] + for col in plan.group_by_cols: + if col in core_tbl.columns: + gb_exprs.append(core_tbl[col]) + elif col in merged_dimensions: + dim_fn = merged_dimensions[col] + gb_exprs.append(dim_fn(core_tbl).name(col)) -class SemanticOrderByOp(Relation): - source: Relation - keys: tuple[ - str | ir.Value | Callable, - ..., - ] # Transformed to tuple[str | _CallableWrapper, ...] in __init__ + agg_exprs = {name: fn(core_tbl) for name, fn in plan.agg_specs.items()} + result = core_tbl.group_by(gb_exprs).aggregate(**agg_exprs) + else: + agg_exprs = {name: fn(core_tbl) for name, fn in plan.agg_specs.items()} + result = core_tbl.aggregate(**agg_exprs) - def __init__(self, source: Relation, keys: Iterable[str | ir.Value | Callable]) -> None: - def wrap_key(k): - return k if isinstance(k, str | _CallableWrapper) else _ensure_wrapped(k) + # Handle calculated measures + if plan.calc_specs: + from ..compile_all import compile_calc_measures - super().__init__( - source=Relation.__coerce__(source), - keys=tuple(wrap_key(k) for k in keys), - ) + result = compile_calc_measures(result, plan.calc_specs) - def __repr__(self) -> str: - return _semantic_repr(self) + # --- 3. LEFT JOIN deferred dimension tables --- + for d in deferrable: + dim_tbl = _to_untagged(d.table_op) - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values + # Reuse the original join predicate to preserve key pairing + # (avoids mismatch from independently sorting left/right keys) + join_preds = d.on_predicate(result, dim_tbl) - @property - def schema(self) -> Schema: - return self.source.schema + # Compute deferred dimension columns on the dimension table. + # For direct columns (e.g., cc_name) this is a no-op rename. + # For derived expressions (e.g., cc_name.upper()) this materializes + # the computed column so it's available after the join. + dim_cols_to_add = [] + right_dims = d.table_op.get_dimensions() + for dim_name in d.deferred_dims: + short = dim_name.split(".", 1)[1] if "." in dim_name else dim_name + if short in right_dims: + dim_fn = right_dims[short] + if callable(dim_fn): + try: + expr = dim_fn(dim_tbl) + col_name = expr.get_name() + if col_name in dim_tbl.columns: + # Direct column — use as-is + dim_cols_to_add.append((dim_name, col_name)) + else: + # Derived expression — mutate onto dim table + # Use a temp name to avoid collisions + temp_name = f"__deferred_{short}" + dim_tbl = dim_tbl.mutate(**{temp_name: expr}) + dim_cols_to_add.append((dim_name, temp_name)) + except Exception: + pass - def to_untagged(self): - tbl = _to_untagged(self.source) + if dim_cols_to_add: + # Perform the LEFT JOIN + joined = result.left_join(dim_tbl, join_preds) + # Select: all from result + deferred dim columns renamed to prefixed names + select_exprs = [result] + for prefixed_name, raw_col in dim_cols_to_add: + select_exprs.append(dim_tbl[raw_col].name(prefixed_name)) + result = joined.select(select_exprs) - def resolve_order_key(key): - if isinstance(key, str): - return tbl[key] if key in tbl.columns else getattr(tbl, key, key) - elif isinstance(key, _CallableWrapper): - unwrapped = _unwrap(key) - return _resolve_expr(unwrapped, tbl) - return key + # --- 4. Select only originally requested columns --- + original_cols = list(self.keys) + list(self.aggs.keys()) + # Add deferred dim names (they were requested) + for d in deferrable: + for dd in d.deferred_dims: + if dd not in original_cols: + original_cols.append(dd) + # Also keep calc measure names + if plan.calc_specs: + for cm in plan.calc_specs: + if cm not in original_cols: + original_cols.append(cm) - return tbl.order_by([resolve_order_key(key) for key in self.keys]) + available = frozenset(result.columns) + select_cols = [c for c in original_cols if c in available] + if select_cols and set(select_cols) != available: + result = result.select([result[c] for c in select_cols]) - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() + return result - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() + # -- helpers for _to_untagged_with_preagg -------------------------------- - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() + @staticmethod + def _join_preagg_with_dim_bridge( + preagg_results, plan, tbl, group_by_cols, decomposed_means=(), reagg_ops=() + ): + """Join pre-aggregated tables using per-table dimension bridges. + ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. + """ + from ..compile_all import _join_tables -class SemanticLimitOp(Relation): - source: Relation - n: int - offset: int + reagg_map = dict(reagg_ops) + # Include decomposed auxiliary columns in measure names + aux_cols = frozenset(c for _, (sc, cc) in decomposed_means for c in (sc, cc)) + measure_names = ( + frozenset(plan.agg_specs.keys()) | frozenset(plan.calc_specs.keys()) | aux_cols + ) + gb_set = frozenset(group_by_cols) - def __init__(self, source: Relation, n: int, offset: int = 0) -> None: - if n <= 0: - raise ValueError(f"limit must be positive, got {n}") - if offset < 0: - raise ValueError(f"offset must be non-negative, got {offset}") - super().__init__(source=Relation.__coerce__(source), n=n, offset=offset) + def _rejoin_one(pt): + pt_grain = tuple(c for c in pt.columns if c not in measure_names) + pt_meas = tuple(c for c in pt.columns if c in measure_names) - def __repr__(self) -> str: - return _semantic_repr(self) + if gb_set <= frozenset(pt_grain): + # Already has group-by columns — re-aggregate if over-grouped + if frozenset(pt_grain) != gb_set: + re_aggs = {m: _build_reagg(pt[m], reagg_map.get(m, "sum")) for m in pt_meas} + return pt.group_by([pt[c] for c in group_by_cols]).aggregate(**re_aggs) + return pt - @property - def values(self) -> FrozenOrderedDict[str, Any]: - return self.source.values + if not pt_grain: + return pt - @property - def schema(self) -> Schema: - return self.source.schema + # Build a per-table dim bridge with ONLY this table's grain cols + bridge_cols = tuple( + dict.fromkeys(c for c in (*group_by_cols, *pt_grain) if c in tbl.columns) + ) + if not bridge_cols: + return pt - def to_untagged(self): - tbl = _to_untagged(self.source) - return tbl.limit(self.n) if self.offset == 0 else tbl.limit(self.n, offset=self.offset) + dim_bridge = tbl.select([tbl[c] for c in bridge_cols]).distinct() + common = tuple(c for c in pt_grain if c in dim_bridge.columns) + if not common: + return pt - def get_dimensions(self) -> Mapping[str, Dimension]: - """Get dictionary of dimensions from source.""" - return self.source.get_dimensions() + preds = [dim_bridge[c] == pt[c] for c in common] + joined_pt = dim_bridge.left_join(pt, preds).select( + [dim_bridge] + [pt[c] for c in pt_meas] + ) + gb_avail = tuple(c for c in group_by_cols if c in joined_pt.columns) + if gb_avail: + re_aggs = { + m: _build_reagg(joined_pt[m], reagg_map.get(m, "sum")) + for m in pt_meas + if m in joined_pt.columns + } + if re_aggs: + joined_pt = joined_pt.group_by([joined_pt[c] for c in gb_avail]).aggregate( + **re_aggs + ) + return joined_pt - def get_measures(self) -> Mapping[str, Measure]: - """Get dictionary of measures from source.""" - return self.source.get_measures() + rejoined = tuple(_rejoin_one(pt) for pt in preagg_results) + result = _join_tables(group_by_cols, list(rejoined)) - def get_calculated_measures(self) -> Mapping[str, Any]: - """Get dictionary of calculated measures from source.""" - return self.source.get_calculated_measures() + # Recompute MEAN from decomposed SUM/COUNT after combining + for mname, (sc, cc) in decomposed_means: + if sc in result.columns and cc in result.columns: + result = result.mutate(**{mname: result[sc] / result[cc]}) + result = result.drop(sc, cc) + return result -def _get_field_type_str(field_type: Any) -> str: - return ( - "string" - if field_type.is_string() - else "number" - if field_type.is_numeric() - else "date" - if field_type.is_temporal() - else str(field_type) - ) + @staticmethod + def _build_minimal_dim_bridge( + preagg_results, + plan, + group_by_cols, + join_tree_info, + merged_dimensions, + decomposed_means=(), + reagg_ops=(), + ): + """Build dim bridges from raw tables when full join is unavailable. + When the full ibis join fails (e.g. column collisions with 3+ + ``join_many`` arms sharing the same key), we build per-table dimension + bridges using only the raw single-table data already captured in + *join_tree_info*. This avoids the ibis collision entirely because we + never join more than two tables at once. -def _get_weight_expr( - base_tbl: Any, - by_measure: str | None, - all_roots: list, - is_string: bool, -) -> Any: - from ._xorq import api as xo + ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. + """ + from ..compile_all import _join_tables - if not by_measure: - return xo._.count() + reagg_map = dict(reagg_ops) + aux_cols = frozenset(c for _, (sc, cc) in decomposed_means for c in (sc, cc)) + measure_names = ( + frozenset(plan.agg_specs.keys()) | frozenset(plan.calc_specs.keys()) | aux_cols + ) + gb_set = frozenset(group_by_cols) - merged_measures = _get_merged_fields(all_roots, "measures") - return ( - merged_measures[by_measure](base_tbl) if by_measure in merged_measures else xo._.count() - ) + def _bridge_one_preagg(pt): + pt_grain = tuple(c for c in pt.columns if c not in measure_names) + pt_meas = tuple(c for c in pt.columns if c in measure_names) + # (a) Pre-agg already carries all group-by cols — re-aggregate. + if gb_set <= frozenset(pt_grain): + if frozenset(pt_grain) != gb_set: + re_aggs = {m: _build_reagg(pt[m], reagg_map.get(m, "sum")) for m in pt_meas} + return pt.group_by([pt[c] for c in group_by_cols]).aggregate(**re_aggs) + return pt -def _build_string_index_fragment( - base_tbl: Any, - field_expr: Any, - field_name: str, - field_path: str, - type_str: str, - weight_expr: Any, -) -> Any: - from ._xorq import api as xo + # (b) Scalar pre-agg — nothing to bridge. + if not pt_grain: + return pt - return ( - base_tbl.group_by(field_expr.name("value")) - .aggregate(weight=weight_expr) - .select( - fieldName=xo.literal(field_name.split(".")[-1]), - fieldPath=xo.literal(field_path), - fieldType=xo.literal(type_str), - fieldValue=xo._["value"].cast("string"), - weight=xo._["weight"], - ) - ) + # (c) Bridge missing group-by dims from raw tables. + bridged = pt + for gb_col in group_by_cols: + if gb_col in bridged.columns: + continue + bridged = _attach_dim_column( + bridged, + gb_col, + measure_names, + join_tree_info, + merged_dimensions, + ) + # Re-aggregate onto the requested group-by granularity. + gb_avail = tuple(c for c in group_by_cols if c in bridged.columns) + re_aggs = { + m: _build_reagg(bridged[m], reagg_map.get(m, "sum")) + for m in pt_meas + if m in bridged.columns + } + match (gb_avail, bool(re_aggs)): + case ((), _) | (_, False): + return bridged + case _: + return bridged.group_by([bridged[c] for c in gb_avail]).aggregate(**re_aggs) -def _build_numeric_index_fragment( - base_tbl: Any, - field_expr: Any, - field_name: str, - field_path: str, - type_str: str, - weight_expr: Any, -) -> Any: - from ._xorq import api as xo + rejoined = tuple(_bridge_one_preagg(pt) for pt in preagg_results) + result = _join_tables(group_by_cols, list(rejoined)) - return ( - base_tbl.select(field_expr.name("value")) - .filter(xo._["value"].notnull()) - .aggregate( - min_val=xo._["value"].min(), - max_val=xo._["value"].max(), - weight=weight_expr, - ) - .select( - fieldName=xo.literal(field_name.split(".")[-1]), - fieldPath=xo.literal(field_path), - fieldType=xo.literal(type_str), - fieldValue=( - xo._["min_val"].cast("string") + " to " + xo._["max_val"].cast("string") - ), - weight=xo._["weight"], - ) - ) + # Recompute MEAN from decomposed SUM/COUNT after combining + for mname, (sc, cc) in decomposed_means: + if sc in result.columns and cc in result.columns: + result = result.mutate(**{mname: result[sc] / result[cc]}) + result = result.drop(sc, cc) + return result -def _resolve_selector( - selector: str | list[str] | Callable | None, - base_tbl: ir.Table, -) -> tuple[str, ...]: - if selector is None: - return tuple(base_tbl.columns) - try: - selected = base_tbl.select(selector) - return tuple(selected.columns) - except Exception: - return [] + @staticmethod + def _apply_calc_specs(result, plan, tbl): + """Apply calculated measure specs (ratios, percent-of-total, etc.).""" + from ..compile_all import _collect_all_refs, _compile_formula + needed_totals: set[str] = set() + for ast in plan.calc_specs.values(): + _collect_all_refs(ast, needed_totals) -def _get_fields_to_index( - selector: str | list[str] | Callable | None, - merged_dimensions: dict, - base_tbl: ir.Table, -) -> tuple[str, ...]: - if selector is None: - selector = s.all() + if needed_totals: + totals_aggs = {ref: result[ref].sum() for ref in needed_totals if ref in result.columns} + all_tbl = result.aggregate(**totals_aggs) if totals_aggs else None + else: + all_tbl = None - raw_fields = _resolve_selector(selector, base_tbl) + out = result.cross_join(all_tbl) if all_tbl is not None else result + calc_cols = { + name: _compile_formula(ast, out, all_tbl, tbl if tbl is not None else out) + for name, ast in plan.calc_specs.items() + } + return out.mutate(**calc_cols) - if not raw_fields: - result = list(merged_dimensions.keys()) - result.extend(col for col in base_tbl.columns if col not in result) - else: - result = [col for col in raw_fields if col in merged_dimensions or col in base_tbl.columns] - return result +from ._mutate_unnest import SemanticMutateOp, SemanticUnnestOp # noqa: E402, F401 -class SemanticIndexOp(Relation): - source: Relation - selector: str | list[str] | tuple[str, ...] | Callable | None - by: str | None = None - sample: int | None = None +class SemanticJoinOp(Relation): + left: Relation + right: Relation + how: str + on: ( + Callable[[Any, Any], Any] | None + ) # Returns BooleanValue from either ibis or xorq.vendor.ibis + cardinality: str # "one", "many", or "cross" def __init__( self, - source: Relation, - selector: str | list[str] | tuple[str, ...] | Callable | None = None, - by: str | None = None, - sample: int | None = None, + left: Relation, + right: Relation, + how: str = "left", + on: Callable[[Any, Any], Any] | None = None, + cardinality: str = "one", ) -> None: - # Validate sample parameter - if sample is not None and sample <= 0: - raise ValueError(f"sample must be positive, got {sample}") - - # Validate 'by' measure exists if provided - if by is not None: - all_roots = _find_all_root_models(source) - if all_roots: - merged_measures = _get_merged_fields(all_roots, "measures") - if by not in merged_measures: - available = list(merged_measures.keys()) - raise KeyError( - f"Unknown measure '{by}' for weight calculation. " - f"Available measures: {', '.join(available) or 'none'}", - ) - - # Convert selector to tuple if it's a list (Ibis requires hashable types) - hashable_selector = tuple(selector) if isinstance(selector, list) else selector - super().__init__( - source=Relation.__coerce__(source), - selector=hashable_selector, - by=by, - sample=sample, + left=Relation.__coerce__(left), + right=Relation.__coerce__(right), + how=how, + on=on, + cardinality=cardinality, ) def __repr__(self) -> str: @@ -4402,929 +2034,875 @@ def __repr__(self) -> str: @property def values(self) -> FrozenOrderedDict[str, Any]: - from ._xorq import api as xo - - return FrozenOrderedDict( - { - "fieldName": xo.literal("").op(), - "fieldPath": xo.literal("").op(), - "fieldType": xo.literal("").op(), - "fieldValue": xo.literal("").op(), - "weight": xo.literal(0).op(), - }, - ) - - @property - def schema(self) -> Schema: - return Schema( - { - "fieldName": "string", - "fieldPath": "string", - "fieldType": "string", - "fieldValue": "string", - "weight": "int64", - }, - ) + vals: dict[str, Any] = {} + vals.update(self.left.values) + vals.update(self.right.values) + return FrozenOrderedDict(vals) @property - def keys(self) -> tuple[str, ...]: - return ("fieldValue", "fieldName", "fieldPath", "fieldType") + def schema(self): + """Get schema of semantic table. - @property - def aggs(self) -> dict[str, Any]: - return {"weight": lambda t: t.weight} + Uses runtime imports to handle both regular ibis and xorq (which vendors ibis). + Converts dtypes to strings to allow Schema to parse them into the correct dtype objects. + """ + fields_dict = {name: str(v.dtype) for name, v in self.values.items()} + return _make_schema(fields_dict) - def to_untagged(self): - all_roots = _find_all_root_models(self.source) - base_tbl = ( - _to_untagged(self.source).limit(self.sample) - if self.sample - else _to_untagged(self.source) + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions with metadata.""" + all_roots = _find_all_root_models(self) + return _merge_fields_with_prefixing( + all_roots, + lambda r: _get_field_dict(r, "dimensions"), + source=self, # Pass self to extract join keys ) - merged_dimensions = _get_merged_fields(all_roots, "dimensions") - fields_to_index = _get_fields_to_index( - self.selector, - merged_dimensions, - base_tbl, + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of base measures with metadata.""" + all_roots = _find_all_root_models(self) + return _merge_fields_with_prefixing( + all_roots, + lambda r: _get_field_dict(r, "measures"), + source=self, ) - if not fields_to_index: - from ._xorq import api as xo - - return xo.memtable( - { - "fieldName": [], - "fieldPath": [], - "fieldType": [], - "fieldValue": [], - "weight": [], - }, - ) - - def build_fragment(field_name: str) -> Any: - field_expr = ( - merged_dimensions[field_name](base_tbl) - if field_name in merged_dimensions - else base_tbl[field_name] - ) - field_type = field_expr.type() - type_str = _get_field_type_str(field_type) - weight_expr = _get_weight_expr( - base_tbl, - self.by, - all_roots, - field_type.is_string(), - ) - - return ( - _build_string_index_fragment( - base_tbl, - field_expr, - field_name, - field_name, - type_str, - weight_expr, - ) - if field_type.is_string() or not field_type.is_numeric() - else _build_numeric_index_fragment( - base_tbl, - field_expr, - field_name, - field_name, - type_str, - weight_expr, - ) - ) + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures with metadata.""" + all_roots = _find_all_root_models(self) + return _merge_fields_with_prefixing( + all_roots, + lambda r: _get_field_dict(r, "calc_measures"), + source=self, + ) - fragments = [build_fragment(f) for f in fields_to_index] - return reduce(lambda acc, frag: acc.union(frag), fragments[1:], fragments[0]) + @property + def dimensions(self) -> tuple[str, ...]: + """Get tuple of dimension names.""" + return tuple(self.get_dimensions().keys()) - def filter(self, predicate: Callable) -> SemanticFilter: - from .expr import SemanticFilter + @property + def _dims(self) -> dict[str, Dimension]: + return dict(self.get_dimensions()) - return SemanticFilter(source=self, predicate=predicate) + @property + def _base_measures(self) -> dict[str, Measure]: + return dict(self.get_measures()) - def order_by(self, *keys: str | ir.Value | Callable) -> SemanticOrderBy: - from .expr import SemanticOrderBy + @property + def _calc_measures(self) -> dict[str, Any]: + return dict(self.get_calculated_measures()) - return SemanticOrderBy(source=self, keys=keys) + @property + def calc_measures(self) -> dict[str, Any]: + """Get calculated measures as dict (for consistency with SemanticModel).""" + return dict(self.get_calculated_measures()) - def limit(self, n: int, offset: int = 0) -> SemanticLimit: - from .expr import SemanticLimit + @property + def measures(self) -> tuple[str, ...]: + return tuple(self.get_measures().keys()) + tuple( + self.get_calculated_measures().keys(), + ) - return SemanticLimit(source=self, n=n, offset=offset) + @property + def json_definition(self) -> Mapping[str, Any]: + return _build_json_definition(self.get_dimensions(), self.get_measures(), None) - def execute(self): - return _rebind_to_canonical_backend(self.to_untagged()).execute() + @property + def name(self) -> str | None: + return None - def as_expr(self): - """Return self as expression.""" - return self + @property + def description(self) -> str | None: + """Get description for joined model by combining root model descriptions.""" + roots = _find_all_root_models(self) + base_descriptions = [] + for root in roots: + root_name = getattr(root, "name", None) or "unnamed" + root_desc = getattr(root, "description", None) + if root_desc: + base_descriptions.append(f"{root_name} ({root_desc})") + else: + base_descriptions.append(root_name) + if base_descriptions: + return "Joined model combining: " + ", ".join(base_descriptions) + return None - def compile(self, **kwargs): - return self.to_untagged().compile(**kwargs) + @property + def table(self): + return self.to_untagged() - def sql(self, **kwargs): - return ibis.to_sql(self.to_untagged(), **kwargs) + def query( + self, + dimensions: Sequence[str] | None = None, + measures: Sequence[str] | None = None, + filters: list | None = None, + order_by: Sequence[tuple[str, str]] | None = None, + limit: int | None = None, + time_grain: str | None = None, + time_range: dict[str, str] | None = None, + having: list | None = None, + ): + from ..query import query as build_query - def __getitem__(self, key): - return self.to_untagged()[key] + return build_query( + semantic_table=self, + dimensions=dimensions, + measures=measures, + filters=filters, + order_by=order_by, + limit=limit, + time_grain=time_grain, + time_range=time_range, + having=having, + ) - def pipe(self, func, *args, **kwargs): - return func(self, *args, **kwargs) + def with_dimensions(self, **dims) -> SemanticTable: + return _semantic_table( + table=self.to_untagged(), + dimensions={**self.get_dimensions(), **dims}, + measures=self.get_measures(), + calc_measures=self.get_calculated_measures(), + name=None, + ) + def with_measures(self, **meas) -> SemanticTable: + joined_tbl = self.to_untagged() + all_known = ( + list(self.get_measures().keys()) + + list(self.get_calculated_measures().keys()) + + list(meas.keys()) + ) + scope = MeasureScope(_tbl=joined_tbl, _known=all_known) -def _find_root_model(node: Any) -> SemanticTableOp | None: - """Find root SemanticTableOp in the operation tree.""" - cur = node - while cur is not None: - if isinstance(cur, SemanticTableOp): - return cur - parent = getattr(cur, "source", None) - cur = parent - return None + new_base, new_calc = ( + dict(self.get_measures()), + dict(self.get_calculated_measures()), + ) + for name, fn_or_expr in meas.items(): + kind, value = _classify_measure(fn_or_expr, scope, name) + (new_calc if kind == "calc" else new_base)[name] = value + return _semantic_table( + table=joined_tbl, + dimensions=self.get_dimensions(), + measures=new_base, + calc_measures=new_calc, + name=None, + _source_join=self, # Pass join reference for projection pushdown + ) -def _find_all_root_models(node: Any) -> tuple[SemanticTableOp, ...]: - """Find all root SemanticTableOps in the operation tree (handles joins with multiple roots).""" - if isinstance(node, SemanticTableOp): - return [node] + def group_by(self, *keys: str) -> SemanticGroupBy: + from ..expr import SemanticGroupBy - roots = [] + return SemanticGroupBy(source=self, keys=keys) - if hasattr(node, "left") and hasattr(node, "right"): - roots.extend(_find_all_root_models(node.left)) - roots.extend(_find_all_root_models(node.right)) - elif hasattr(node, "source") and node.source is not None: - roots.extend(_find_all_root_models(node.source)) + def filter(self, predicate: Callable) -> SemanticFilter: + from ..expr import SemanticFilter - return roots + return SemanticFilter(source=self, predicate=predicate) + def join_one( + self, + other: SemanticTable, + on: Callable[[Any, Any], ir.BooleanValue], + how: str = "left", + ): + """Join with one-to-one relationship semantics (left outer join).""" + from ..expr import SemanticJoin -def _dimension_only_source_table( - keys: tuple[str, ...], - all_roots: Sequence[SemanticTableOp], - filters: tuple, -) -> tuple[SemanticTableOp, list[str], tuple] | None: - """Check if a dimension-only query can be routed to a single source table. + return SemanticJoin( + left=self, + right=other.op(), + on=on, + how=how, + cardinality="one", + ) - When all requested dimension keys share a single table prefix and that - prefix maps to a root model whose dimensions cover every key, we can - bypass the join and query the dimension table directly. This ensures - dimension members with no matching fact rows are still returned. + def join_many( + self, + other: SemanticTable, + on: Callable[[Any, Any], ir.BooleanValue], + how: str = "left", + ): + """Join with one-to-many relationship semantics.""" + from ..expr import SemanticJoin - *filters* are the ``_CallableWrapper`` predicates collected between the - aggregate and the underlying join. Filters whose column references all - belong to the target table are forwarded; if any filter references columns - outside the target table the shortcut is disabled. + return SemanticJoin( + left=self, + right=other.op(), + on=on, + how=how, + cardinality="many", + ) - Returns ``(root_op, unprefixed_keys, applicable_filters)`` or ``None``. - """ - if not keys: - return None + def join_cross(self, other: SemanticTable): + """Cross join (Cartesian product) with another semantic model.""" + from ..expr import SemanticJoin - prefixes: set[str] = set() - unprefixed: list[str] = [] - for key in keys: - if "." not in key: - return None # Non-prefixed key — can't determine source - prefix, name = key.split(".", 1) - prefixes.add(prefix) - unprefixed.append(name) - - if len(prefixes) != 1: - return None # Keys span multiple tables - - target_prefix = next(iter(prefixes)) - - for root in all_roots: - if root.name == target_prefix: - root_dims = root.get_dimensions() - if all(k in root_dims for k in unprefixed): - # Validate that every filter only touches columns present - # on the target dimension table. If any filter references - # columns from other tables we cannot use the shortcut. - if filters: - tbl = _to_untagged(root) - tbl_cols = frozenset(tbl.columns) | frozenset(root_dims) - for flt in filters: - fn = _unwrap(flt) if hasattr(flt, "unwrap") else flt - extraction = _extract_columns_from_callable(fn, tbl) - if extraction.extraction_failed: - return None # Can't determine — bail out - if not extraction.columns <= tbl_cols: - return None # References columns outside target - return root, unprefixed, filters + return SemanticJoin( + left=self, + right=other.op(), + on=None, + how="cross", + cardinality="cross", + ) - return None + def join(self, *args, **kwargs): + """Deprecated: Use join_one(), join_many(), or join_cross() instead.""" + raise TypeError(_JOIN_REMOVED_MESSAGE) + def index( + self, + selector: str | list[str] | Callable | None = None, + by: str | None = None, + sample: int | None = None, + ) -> SemanticIndexOp: + """Create an index for search/discovery. -def _build_join_depth_map(node: Any) -> dict[str, int]: - """Map each leaf table name to its actual ibis rname depth. + Supports ibis selectors (s.all(), s.cols(), etc.). + """ - ``SemanticJoinOp.to_untagged`` calls ``_join_depth`` to determine the - rname suffix for each join level. ``_join_depth`` counts the number - of ``SemanticJoinOp`` ancestors on the *left* spine. The right child - at depth *d* gets ``rname = _rname_for_depth(d)``. + # Handle ibis selectors + processed_selector = selector + if selector is not None and "ibis.selectors" in str(type(selector).__module__): + # Handle s.all() - select all columns + if type(selector).__name__ == "AllColumns": + processed_selector = None + # Handle s.cols() - select specific columns + elif type(selector).__name__ == "Cols": + # Extract column names from the Cols selector + processed_selector = sorted(selector.names) + # For other selectors, keep as-is + else: + processed_selector = selector - For nested subtrees on the right side of a join, ibis applies the - inner subtree's rname independently. So ``aircraft_models`` at inner - depth 1 gets ``_right``, not ``_right3`` even if the outer depth is 3. + return SemanticIndexOp(source=self, selector=processed_selector, by=by, sample=sample) - This function mirrors ``_join_depth`` logic: walk down the left spine, - recording the right child's depth at each level. If the right child is - itself a join tree, recurse to get inner depths for its leaves. - """ - depth_map: dict[str, int] = {} - - def _record_leaf(n, depth: int): - """Record a leaf table at the given depth.""" - if isinstance(n, SemanticTableOp): - name = n.name - if name and name not in depth_map: - depth_map[name] = depth - - def _walk_join_spine(n): - """Walk the left spine of a join tree, recording depths.""" - if not isinstance(n, SemanticJoinOp): - # Leftmost leaf: depth 0 (root, never renamed) - _record_leaf(n, 0) - return + def _collect_leaf_table_names(self) -> set[str]: + """Collect names of all leaf (base) tables in this join tree.""" + tables = set() - depth = SemanticJoinOp._join_depth(n) - # The right child is joined at this depth - right = n.right - if isinstance(right, SemanticJoinOp): - # Right is a subtree — its leaves get inner depths - inner_map = _build_join_depth_map(right) - for tname, idepth in inner_map.items(): - if tname not in depth_map: - if idepth == 0: - # Leftmost leaf of subtree sits at the outer depth - # (it receives the outer rname suffix if conflicting) - depth_map[tname] = depth - else: - # Inner leaves keep their inner depth (inner rname) - depth_map[tname] = idepth + if isinstance(self.left, SemanticJoinOp): + tables |= self.left._collect_leaf_table_names() else: - _record_leaf(right, depth) - - # Recurse down the left spine - _walk_join_spine(n.left) - - _walk_join_spine(node) - return depth_map + left_name = getattr(self.left, "name", None) + if left_name: + tables.add(left_name) + if isinstance(self.right, SemanticJoinOp): + tables |= self.right._collect_leaf_table_names() + else: + right_name = getattr(self.right, "name", None) + if right_name: + tables.add(right_name) -def _update_measure_refs_in_calc(expr, prefix_map: dict[str, str]): - """ - Recursively update MeasureRef names in a calculated measure expression. + return tables - Args: - expr: A MeasureExpr (MeasureRef, AllOf, BinOp, MethodCall, or literal) - prefix_map: Mapping from old name to new prefixed name + @property + def required_columns(self): + """ + Column requirements for projection pushdown. - Returns: - Updated expression with prefixed MeasureRef names - """ - if isinstance(expr, MeasureRef): - # Update the measure reference name if it's in the map - new_name = prefix_map.get(expr.name, expr.name) - return MeasureRef(new_name) - elif isinstance(expr, AllOf): - # Update the inner MeasureRef - updated_ref = _update_measure_refs_in_calc(expr.ref, prefix_map) - return AllOf(updated_ref) - elif isinstance(expr, MethodCall): - updated_receiver = _update_measure_refs_in_calc(expr.receiver, prefix_map) - return MethodCall( - receiver=updated_receiver, - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, - ) - elif isinstance(expr, BinOp): - # Recursively update left and right - updated_left = _update_measure_refs_in_calc(expr.left, prefix_map) - updated_right = _update_measure_refs_in_calc(expr.right, prefix_map) - return BinOp(op=expr.op, left=updated_left, right=updated_right) - else: - # Literal number or other - return as-is - return expr + This property makes projection pushdown intrinsic to the join operation, + similar to how `schema` is intrinsic to a relation. + Computes what columns are needed from each leaf table based on: + 1. Columns needed for measures defined on the joined tables + 2. Join key columns -def _extract_join_key_column_names(source: Relation) -> set[str]: - """ - Extract column names that ibis will merge (coalesce) during joins. + Returns: + Dict mapping table names to sets of required column names. + """ + return self._compute_required_columns() - Ibis only merges join-key columns when **both** sides of an equi-join share - the **same** column name (e.g., ``l.code == r.code``). When names differ - (e.g., ``l.carrier == r.code``), the right column gets a ``_right`` suffix - instead. We return only the intersection of left/right key names so that - ``_check_and_add_rename`` correctly detects columns that need renaming. + def _compute_required_columns(self, parent_requirements: dict[str, set[str]] | None = None): + """ + Compute column requirements for projection pushdown. - Args: - source: The relation to search for join operations + Args: + parent_requirements: Optional dict of specific columns requested by parent operations. - Returns: - Set of column names that ibis merges (same-name equi-join keys) - """ - join_keys: set[str] = set() + Returns: + Dict mapping table names to sets of required column names. + """ + # Start with parent requirements using immutable TableRequirements + requirements = projection_utils.TableRequirements.from_dict( + parent_requirements if parent_requirements else {} + ) - def find_joins(node): - """Recursively find join operations and extract merged key columns.""" - if isinstance(node, SemanticJoinOp) and node.on: - try: - left_expr = node.left.to_expr() if hasattr(node.left, "to_expr") else node.left - right_expr = node.right.to_expr() if hasattr(node.right, "to_expr") else node.right - result = _extract_join_key_columns(node.on, left_expr, right_expr) - if result.is_success(): - # ibis merges only same-name equi-join columns - join_keys.update(result.left_columns & result.right_columns) - except (AttributeError, TypeError): - pass + # Get all root models in join tree + all_roots = _find_all_root_models(self) - if hasattr(node, "left") and isinstance(node.left, Relation): - find_joins(node.left) - if hasattr(node, "right") and isinstance(node.right, Relation): - find_joins(node.right) - if hasattr(node, "source") and isinstance(node.source, Relation): - find_joins(node.source) + # Collect leaf tables + def collect_leaf_tables(node): + if isinstance(node, SemanticJoinOp): + return collect_leaf_tables(node.left) + collect_leaf_tables(node.right) + table_name = getattr(node, "name", None) + return [(table_name, _to_untagged(node))] if table_name else [] - find_joins(source) - return join_keys + leaf_tables = collect_leaf_tables(self) + # Group measures by table + measures_by_table = {} + for root in all_roots: + root_measures = _get_field_dict(root, "measures") + if root.name: + measures_by_table[root.name] = root_measures + else: + # Parse prefixed measures (e.g., "marketing.spend") + for measure_name, measure_obj in root_measures.items(): + if "." in measure_name: + table_name = measure_name.split(".", 1)[0] + if table_name not in measures_by_table: + measures_by_table[table_name] = {} + measures_by_table[table_name][measure_name] = measure_obj -def _build_column_rename_map( - all_roots: Sequence[SemanticTable], - field_accessor: callable, - source: Relation | None = None, -) -> dict[str, str]: - """ - Build a mapping of dimension names to their renamed column names in joined tables. + # Extract columns needed by measures (using immutable operations) + for table_name, table_ibis in leaf_tables: + if table_name in measures_by_table: + for measure_obj in measures_by_table[table_name].values(): + # All measure_obj values are Measure instances with expr attribute + measure_fn = measure_obj.expr + if callable(measure_fn): + cols = projection_utils.extract_columns_from_callable_safe( + measure_fn, table_ibis + ) + if cols: + requirements = requirements.add_columns(table_name, cols) - When Ibis joins tables with duplicate column names, it renames columns from later - tables with '_right' suffix. However, columns used as join keys are merged and - NOT renamed, so we exclude them from the rename map. + # Extract and add join key columns + if self.on is not None: + # Get full schema for join key extraction + temp_left = ( + self.left.to_untagged() + if isinstance(self.left, SemanticJoinOp) + else _to_untagged(self.left) + ) + temp_right = ( + self.right.to_untagged() + if isinstance(self.right, SemanticJoinOp) + else _to_untagged(self.right) + ) - Uses graph_utils for generic traversal and the returns library for safe handling. + join_keys_result = _extract_join_key_columns(self.on, temp_left, temp_right) - Args: - all_roots: List of root semantic tables in join order - field_accessor: Function to get fields (dimensions) from a root - source: Optional source relation to extract join keys from + if join_keys_result.is_success(): + # Add join keys to leaf tables (immutable operations) + if isinstance(self.left, SemanticJoinOp): + for col in join_keys_result.left_columns: + for leaf_name in self.left._collect_leaf_table_names(): + leaf_table = self._get_leaf_table_by_name(self.left, leaf_name) + if leaf_table and col in _to_untagged(leaf_table).columns: + requirements = requirements.add_columns(leaf_name, {col}) + else: + left_name = getattr(self.left, "name", None) + if left_name: + requirements = requirements.add_columns( + left_name, join_keys_result.left_columns + ) - Returns: - Dict mapping dimension names like 'airports.city' to renamed columns like 'city_right' - """ - # Build column index using graph_utils (returns Result) - from returns.result import Failure + if isinstance(self.right, SemanticJoinOp): + for col in join_keys_result.right_columns: + for leaf_name in self.right._collect_leaf_table_names(): + leaf_table = self._get_leaf_table_by_name(self.right, leaf_name) + if leaf_table and col in _to_untagged(leaf_table).columns: + requirements = requirements.add_columns(leaf_name, {col}) + else: + right_name = getattr(self.right, "name", None) + if right_name: + requirements = requirements.add_columns( + right_name, join_keys_result.right_columns + ) - from .graph_utils import build_column_index_from_roots, extract_column_from_dimension + return requirements.to_dict() - column_index_result = build_column_index_from_roots(all_roots) - if isinstance(column_index_result, Failure): - # If we can't build the index, return empty map (dimensions will use fallback behavior) - return {} + def _get_leaf_table_by_name(self, join_op: SemanticJoinOp, target_name: str): + """Find a leaf table by name in a join tree.""" + if isinstance(join_op.left, SemanticJoinOp): + result = self._get_leaf_table_by_name(join_op.left, target_name) + if result is not None: + return result + else: + left_name = getattr(join_op.left, "name", None) + if left_name == target_name: + return join_op.left - column_index = column_index_result.value_or({}) + if isinstance(join_op.right, SemanticJoinOp): + result = self._get_leaf_table_by_name(join_op.right, target_name) + if result is not None: + return result + else: + right_name = getattr(join_op.right, "name", None) + if right_name == target_name: + return join_op.right - # Extract join key columns to exclude from renaming - join_keys = _extract_join_key_column_names(source) if source else set() + return None - # Build a map from table name → actual ibis join depth by walking the - # join tree. The flat index in all_roots does NOT equal ibis join depth - # for nested joins (e.g. aircraft → aircraft_models inside a flights - # join tree), so we must compute it from the tree structure. - join_depth_map: dict[str, int] = {} - if source is not None: - join_depth_map = _build_join_depth_map(source) + def _collect_join_keys_for_leaves(self) -> dict[str, set[str]]: + """Collect join keys needed by each leaf table. - # Process dimensions and determine which need renamed columns - rename_map = {} + For nested joins, we trace join keys back to their source leaf tables. + Returns dict mapping leaf table names to sets of columns needed for joins. + """ + join_columns: dict[str, set[str]] = {} - for idx, root in enumerate(all_roots): - if not root.name: - continue + # Recursively collect from nested joins + if isinstance(self.left, SemanticJoinOp): + nested_keys = self.left._collect_join_keys_for_leaves() + for table_name, cols in nested_keys.items(): + existing = join_columns.get(table_name, set()) + join_columns[table_name] = existing | cols - fields_dict = field_accessor(root) - if not fields_dict: - continue + if isinstance(self.right, SemanticJoinOp): + nested_keys = self.right._collect_join_keys_for_leaves() + for table_name, cols in nested_keys.items(): + existing = join_columns.get(table_name, set()) + join_columns[table_name] = existing | cols - root_tbl = root.to_untagged() - # Use the actual join depth if available, otherwise fall back to table_idx - effective_depth = join_depth_map.get(root.name, idx) - - for field_name, field_value in fields_dict.items(): - # Extract column name using graph_utils (returns Maybe) - column_maybe = extract_column_from_dimension(field_value, root_tbl) - - # Use Maybe pattern from returns library - column_maybe.bind_optional( - lambda base_column: _check_and_add_rename( # noqa: B023 - rename_map=rename_map, - base_column=base_column, - prefixed_name=f"{root.name}.{field_name}", # noqa: B023 - table_idx=idx, # noqa: B023 - column_index=column_index, - join_keys=join_keys, - join_depth=effective_depth, # noqa: B023 - ) + # Add join keys for THIS level + if self.on is not None: + # Convert without projection to get full schema + temp_left = ( + self.left.to_untagged(parent_requirements=None) + if isinstance(self.left, SemanticJoinOp) + else _to_untagged(self.left) + ) + temp_right = ( + self.right.to_untagged(parent_requirements=None) + if isinstance(self.right, SemanticJoinOp) + else _to_untagged(self.right) ) - return rename_map - + join_keys = _extract_join_key_columns(self.on, temp_left, temp_right) -def _check_and_add_rename( - rename_map: dict[str, str], - base_column: str, - prefixed_name: str, - table_idx: int, - column_index: dict[str, list[int]], - join_keys: set[str], - join_depth: int | None = None, -) -> None: - """ - Check if a column needs renaming and add to rename map if so. - - ``table_idx`` is the flat index in ``all_roots`` used to detect - whether an earlier table has the same column. ``join_depth`` is - the actual ibis join depth (from ``_build_join_depth_map``) used - to compute the ``_right`` / ``_right2`` / … suffix. - - Args: - rename_map: Map to update with renames - base_column: The base column name - prefixed_name: The prefixed dimension name (e.g., 'airports.city') - table_idx: Flat index in all_roots (for conflict detection) - column_index: Index of column occurrences - join_keys: Set of column names used as join keys (these don't get renamed) - join_depth: Actual ibis join depth for suffix computation (defaults to table_idx) - """ - # Skip columns that are join keys - they get merged, not renamed - if base_column in join_keys: - return + if join_keys.is_success(): + # Add join keys to the appropriate leaf tables + if not isinstance(self.left, SemanticJoinOp): + # Left is a leaf table + left_name = getattr(self.left, "name", None) + if left_name: + existing = join_columns.get(left_name, set()) + join_columns[left_name] = existing | join_keys.left_columns + else: + # Left is a nested join - need to map columns back to source tables + # Get all leaf tables from the nested join and their schemas + left_leaves = self.left._collect_leaf_table_names() + for col in join_keys.left_columns: + # Add column to each leaf table that actually has this column + for table_name in left_leaves: + if table_name: + # Check if this table actually has this column + # We do this by converting the table and checking its schema + leaf_table = self._get_leaf_table_by_name(self.left, table_name) + if leaf_table is not None: + leaf_ibis = _to_untagged(leaf_table) + if col in leaf_ibis.columns: + existing = join_columns.get(table_name, set()) + join_columns[table_name] = existing | {col} - depth = join_depth if join_depth is not None else table_idx + if not isinstance(self.right, SemanticJoinOp): + # Right is a leaf table + right_name = getattr(self.right, "name", None) + if right_name: + existing = join_columns.get(right_name, set()) + join_columns[right_name] = existing | join_keys.right_columns + else: + # Right is a nested join + right_leaves = self.right._collect_leaf_table_names() + for col in join_keys.right_columns: + for table_name in right_leaves: + if table_name: + leaf_table = self._get_leaf_table_by_name(self.right, table_name) + if leaf_table is not None: + leaf_ibis = _to_untagged(leaf_table) + if col in leaf_ibis.columns: + existing = join_columns.get(table_name, set()) + join_columns[table_name] = existing | {col} - if base_column in column_index: - tables_with_column = column_index[base_column] - # Check if any table before this one (in flat order) has the same column - earlier_tables = [t for t in tables_with_column if t < table_idx] - if earlier_tables: - suffix = "_right" if depth <= 1 else f"_right{depth}" - rename_map[prefixed_name] = f"{base_column}{suffix}" + return join_columns + def _augment_parent_requirements_for_pruning( + self, + parent_requirements: dict[str, set[str]] | None, + ) -> dict[str, set[str]] | None: + """Expand table requirements with bridge tables needed by descendant joins.""" + if parent_requirements is None: + return None -def _wrap_dimension_for_renamed_column(dimension: Dimension, renamed_column: str) -> Dimension: - """ - Wrap a dimension to access a renamed column in a joined table. + augmented = {table: set(cols) for table, cols in parent_requirements.items()} - Args: - dimension: The original dimension - renamed_column: The renamed column name (e.g., 'city_right') + right_tables = ( + self.right._collect_leaf_table_names() + if isinstance(self.right, SemanticJoinOp) + else {getattr(self.right, "name", None)} - {None} + ) + if augmented.keys() & right_tables and self.on is not None: + temp_left = ( + self.left.to_untagged(parent_requirements=None) + if isinstance(self.left, SemanticJoinOp) + else _to_untagged(self.left) + ) + temp_right = ( + self.right.to_untagged(parent_requirements=None) + if isinstance(self.right, SemanticJoinOp) + else _to_untagged(self.right) + ) + join_keys = _extract_join_key_columns(self.on, temp_left, temp_right) + if join_keys.is_success(): + join_key_columns = join_keys.left_columns | join_keys.right_columns + if isinstance(self.left, SemanticJoinOp): + for col in join_keys.left_columns: + for leaf_name in self.left._collect_leaf_table_names(): + leaf_table = self._get_leaf_table_by_name(self.left, leaf_name) + if leaf_table is not None and col in _to_untagged(leaf_table).columns: + augmented.setdefault(leaf_name, set()) + else: + left_name = getattr(self.left, "name", None) + if left_name: + augmented.setdefault(left_name, set()) - Returns: - A new Dimension that accesses the renamed column - """ + right_needed_tables = right_tables & augmented.keys() + if isinstance(self.left, SemanticJoinOp) and right_needed_tables: + left_leaf_names = self.left._collect_leaf_table_names() + right_needed_leaf_names = [ + leaf_name for leaf_name in right_needed_tables + if self._get_leaf_table_by_name(self, leaf_name) is not None + ] + for left_leaf_name in left_leaf_names: + left_leaf = self._get_leaf_table_by_name(self.left, left_leaf_name) + if left_leaf is None: + continue + left_columns = set(_to_untagged(left_leaf).columns) - join_key_columns + if not left_columns: + continue + for right_leaf_name in right_needed_leaf_names: + right_leaf = self._get_leaf_table_by_name(self, right_leaf_name) + if right_leaf is None: + continue + right_columns = set(_to_untagged(right_leaf).columns) - join_key_columns + if left_columns & right_columns: + augmented.setdefault(left_leaf_name, set()) + break - # Create a new callable that accesses the renamed column - def renamed_accessor(table: ir.Table) -> ir.Value: - return table[renamed_column] - - # Return a new Dimension with the wrapped callable but same metadata - return Dimension( - expr=renamed_accessor, - description=dimension.description, - is_entity=dimension.is_entity, - is_time_dimension=dimension.is_time_dimension, - is_event_timestamp=dimension.is_event_timestamp, - smallest_time_grain=dimension.smallest_time_grain, - derived_dimensions=dimension.derived_dimensions, - ) + if isinstance(self.left, SemanticJoinOp): + augmented = self.left._augment_parent_requirements_for_pruning(augmented) or augmented + if isinstance(self.right, SemanticJoinOp): + augmented = self.right._augment_parent_requirements_for_pruning(augmented) or augmented + return augmented -def _merge_fields_with_prefixing( - all_roots: Sequence[SemanticTable], - field_accessor: callable, - source: Relation | None = None, -) -> FrozenDict[str, Any]: - """ - Generic function to merge any type of fields (dimensions, measures) with prefixing. + def _should_prune_right(self, parent_requirements: dict[str, set[str]] | None) -> bool: + """Return whether the right side can be skipped for the given requirements.""" + if ( + parent_requirements is None + or self.cardinality != "one" + or self.how != "left" + ): + return False - Args: - all_roots: List of SemanticTable root models - field_accessor: Function that takes a root and returns the fields dict (e.g. lambda r: r.dimensions) - source: Optional source relation to extract join keys from for proper column renaming + needed_tables = frozenset(parent_requirements.keys()) + right_tables = ( + self.right._collect_leaf_table_names() + if isinstance(self.right, SemanticJoinOp) + else {getattr(self.right, "name", None)} - {None} + ) + if not right_tables or right_tables & needed_tables: + return False - Returns: - FrozenDict mapping field names (always prefixed with table name) to field values - """ - if not all_roots: - return FrozenDict() - - merged_fields = {} - - is_calc_measures = False - is_dimensions = False - if all_roots: - sample_fields = field_accessor(all_roots[0]) - if sample_fields: - from .measure_scope import AllOf, BinOp, MeasureRef, MethodCall - - first_val = next(iter(sample_fields.values()), None) - is_calc_measures = isinstance( - first_val, - MeasureRef | AllOf | BinOp | MethodCall | int | float, + # If an earlier sibling shares non-join columns with a still-needed + # table, pruning it changes the later table's rname-based aliases + # (e.g. ``state_right2`` vs ``state_right``). Keep the sibling so + # wrapped dimensions keep resolving to the expected columns. + remaining_tables = needed_tables - right_tables + if remaining_tables: + temp_left = ( + self.left.to_untagged(parent_requirements=None) + if isinstance(self.left, SemanticJoinOp) + else _to_untagged(self.left) ) - is_dimensions = isinstance(first_val, Dimension) - - # For dimensions, build a column rename map to handle Ibis join conflicts - column_rename_map = {} - if is_dimensions: - column_rename_map = _build_column_rename_map(all_roots, field_accessor, source) - - for root in all_roots: - root_name = root.name - fields_dict = field_accessor(root) - - if is_calc_measures and root_name: - base_map = ( - {k: f"{root_name}.{k}" for k in root.get_measures()} - if hasattr(root, "get_measures") - else {} + temp_right = ( + self.right.to_untagged(parent_requirements=None) + if isinstance(self.right, SemanticJoinOp) + else _to_untagged(self.right) ) - calc_map = ( - {k: f"{root_name}.{k}" for k in root.get_calculated_measures()} - if hasattr(root, "get_calculated_measures") - else {} + join_keys = _extract_join_key_columns(self.on, temp_left, temp_right) + join_key_columns = ( + (join_keys.left_columns | join_keys.right_columns) + if self.on is not None and join_keys.is_success() + else set() ) - prefix_map = {**base_map, **calc_map} - - for field_name, field_value in fields_dict.items(): - if root_name: - # Always use prefixed name with . separator - prefixed_name = f"{root_name}.{field_name}" - - # If it's a calculated measure, update internal MeasureRefs - if is_calc_measures: - field_value = _update_measure_refs_in_calc(field_value, prefix_map) - # If it's a dimension that needs column renaming, wrap the callable - elif is_dimensions and prefixed_name in column_rename_map: - field_value = _wrap_dimension_for_renamed_column( - field_value, column_rename_map[prefixed_name] - ) - - merged_fields[prefixed_name] = field_value - else: - # Fallback to original name if no root name - merged_fields[field_name] = field_value - return FrozenDict(merged_fields) + for right_name in right_tables: + right_leaf = self._get_leaf_table_by_name(self, right_name) + if right_leaf is None: + continue + right_columns = set(_to_untagged(right_leaf).columns) - join_key_columns + if not right_columns: + continue + for needed_name in remaining_tables: + needed_leaf = self._get_leaf_table_by_name(self, needed_name) + if needed_leaf is None: + continue + needed_columns = set(_to_untagged(needed_leaf).columns) - join_key_columns + if right_columns & needed_columns: + return False + return True -# ============================================================================== -# Column Tracking for Projection Pushdown -# ============================================================================== + def _effective_join_depth(self, parent_requirements: dict[str, set[str]] | None) -> int: + """Count the surviving left-spine joins after pruning.""" + augmented = self._augment_parent_requirements_for_pruning(parent_requirements) + left_depth = ( + self.left._effective_join_depth(augmented) + if isinstance(self.left, SemanticJoinOp) + else 0 + ) + if self._should_prune_right(augmented): + return left_depth + return left_depth + 1 + @staticmethod + def _join_depth(op) -> int: + """Count nested left SemanticJoinOps to determine join depth.""" + depth = 0 + current = op + while isinstance(current, SemanticJoinOp): + depth += 1 + current = current.left + return depth -@frozen -class ColumnTracker: - """Immutable tracker for column references during expression evaluation. + @staticmethod + def _rname_for_depth(depth: int) -> str: + """Return the ``rname`` template for the given join depth. - Uses frozenset for tracked columns. New columns are added by creating - new tracker instances with updated sets. - """ + ibis uses ``{name}_right`` by default. When three or more tables + share a column name the second ``_right`` collides with the first. + We avoid this by appending the depth: ``_right``, ``_right2``, + ``_right3``, … + """ + return "{name}_right" if depth <= 1 else f"{{name}}_right{depth}" - columns: frozenset[str] = field(factory=frozenset, converter=frozenset) + def to_untagged(self, parent_requirements: dict[str, set[str]] | None = None): + """Convert join to Ibis expression, pruning unnecessary joins. - def with_column(self, col_name: str) -> ColumnTracker: - """Return new tracker with additional column.""" - return ColumnTracker(columns=self.columns | {col_name}) + When *parent_requirements* is provided (a dict mapping table names to + sets of needed columns), right-side leaf tables that contribute no + required columns are skipped entirely. This avoids expensive joins + to dimension tables whose columns are never referenced by the query. - def merge(self, other: ColumnTracker) -> ColumnTracker: - """Return new tracker with merged columns.""" - return ColumnTracker(columns=self.columns | other.columns) + Args: + parent_requirements: Optional mapping of ``{table_name: {col, …}}``. + When provided, joins to tables absent from this dict are elided. + Returns: + Ibis join expression (potentially simplified). + """ + from ..convert import _Resolver -@frozen -class ColumnExtractionResult: - """Result of column extraction with error handling. + augmented_requirements = self._augment_parent_requirements_for_pruning(parent_requirements) - Separates successful extraction from error cases. - """ + # --- Join pruning: skip right-side tables not needed by the query --- + # Only prune join_one with LEFT join semantics. Inner joins act as + # a filter (excluding unmatched left rows) and must not be removed. + # join_many / join_cross affect row counts and are never pruned here + # (join_many is intercepted by the pre-aggregation path earlier). + if self._should_prune_right(augmented_requirements): + # Right side contributes nothing the parent needs — skip it. + if isinstance(self.left, SemanticJoinOp): + return self.left.to_untagged(parent_requirements=augmented_requirements) + return _to_untagged(self.left) - columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - extraction_failed: bool = False - error_type: type[Exception] | None = None - - @classmethod - def success(cls, columns: set[str] | frozenset[str]) -> ColumnExtractionResult: - """Create successful result.""" - return cls(columns=frozenset(columns), extraction_failed=False) - - @classmethod - def failure(cls, error: Exception) -> ColumnExtractionResult: - """Create failure result with error information.""" - return cls( - columns=frozenset(), - extraction_failed=True, - error_type=type(error), + # Build both sides, passing requirements down for recursive pruning. + left_tbl = ( + _to_untagged(self.left) + if not isinstance(self.left, SemanticJoinOp) + else self.left.to_untagged(parent_requirements=augmented_requirements) + ) + right_tbl = ( + _to_untagged(self.right) + if not isinstance(self.right, SemanticJoinOp) + else self.right.to_untagged(parent_requirements=augmented_requirements) ) - def is_success(self) -> bool: - """Check if extraction succeeded.""" - return not self.extraction_failed - + # Rebind right side's DatabaseTable ops to use the same backend as + # the left side. from_ibis() creates a separate Backend object per + # call; xorq >=0.3.11 raises "Multiple backends found" unless all + # tables in a join share the same backend instance. + left_tbl, right_tbl = self._rebind_join_backends(left_tbl, right_tbl) -@frozen -class JoinColumnExtractionResult: - """Result of join column extraction for both tables.""" - - left_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - right_columns: frozenset[str] = field(factory=frozenset, converter=frozenset) - extraction_failed: bool = False - error_type: type[Exception] | None = None - - @classmethod - def success( - cls, - left: set[str] | frozenset[str], - right: set[str] | frozenset[str], - ) -> JoinColumnExtractionResult: - """Create successful result.""" - return cls( - left_columns=frozenset(left), - right_columns=frozenset(right), - extraction_failed=False, + depth = ( + self._effective_join_depth(augmented_requirements) + if augmented_requirements is not None + else self._join_depth(self) ) + rname = self._rname_for_depth(depth) - @classmethod - def failure(cls, error: Exception) -> JoinColumnExtractionResult: - """Create failure result with error information.""" - return cls( - left_columns=frozenset(), - right_columns=frozenset(), - extraction_failed=True, - error_type=type(error), - ) + if self.on is None: + return left_tbl.join(right_tbl, how=self.how, rname=rname) - def is_success(self) -> bool: - """Check if extraction succeeded.""" - return not self.extraction_failed + # Detect column name conflicts that cause ibis/xorq to raise + # "Ambiguous field reference" during predicate resolution. + conflicting = frozenset(left_tbl.columns) & frozenset(right_tbl.columns) + if not conflicting: + pred = self.on(_Resolver(left_tbl), _Resolver(right_tbl)) + return left_tbl.join(right_tbl, pred, how=self.how, rname=rname) -def _make_tracking_proxy( - table: ir.Table, - on_access: Callable[[str], None], -) -> Any: - """Create tracking proxy with custom access handler. + # Temporarily rename conflicting left columns so the predicate + # can be resolved without ambiguity. + # ibis rename convention: {new_name: old_name} + rename_left = {f"{_BSL_JOIN_KEY_TMP_PREFIX}{c}": c for c in conflicting} + left_safe = left_tbl.rename(rename_left) - Composable factory that enables different tracking strategies - via the on_access callback. - """ + # Resolver that transparently maps original names → temp names, + # so predicates like ``lambda f, a: f.tail_num == a.tail_num`` + # still work even though left's ``tail_num`` was renamed. + orig_to_tmp = {c: f"{_BSL_JOIN_KEY_TMP_PREFIX}{c}" for c in conflicting} - class _TrackingProxy: - """Proxy that tracks attribute and item access.""" + pred = self.on( + _RenamedResolver(left_safe, orig_to_tmp), + _Resolver(right_tbl), + ) + joined = left_safe.join(right_tbl, pred, how=self.how, rname=rname) - def __init__(self, inner_table: ir.Table, access_handler: Callable[[str], None]): - object.__setattr__(self, "_table", inner_table) - object.__setattr__(self, "_on_access", access_handler) + # Restore final column names (ibis convention: {new: old}): + # - left temp columns → original names + # - right conflicting columns → depth-based rname suffix + rename_final = {c: f"{_BSL_JOIN_KEY_TMP_PREFIX}{c}" for c in conflicting} | { + rname.replace("{name}", c): c for c in conflicting + } - def __getattr__(self, name: str): - if name.startswith("_"): - return getattr(self._table, name) - self._on_access(name) - return getattr(self._table, name) + return joined.rename(rename_final) - def __getitem__(self, name: str): - self._on_access(name) - return self._table[name] + @staticmethod + def _rebind_join_backends(left_tbl, right_tbl): + """Rebind DatabaseTable ops so both sides share a single backend. - return _TrackingProxy(table, on_access) + When tables are individually wrapped via ``from_ibis()``, each gets + a distinct ``Backend`` object. xorq >=0.3.11 raises "Multiple + backends found" unless all tables in a join share the same instance. + Uses ``op.replace()`` (ibis graph rewriting) to swap out the + ``source`` field on every ``DatabaseTable`` node in the right tree. + For plain ibis expressions (backends xorq doesn't wrap, e.g. + BigQuery), ``walk_nodes`` can't traverse the tree — fall back to + returning the inputs unchanged so ibis executes the join natively. + """ + try: + from .._xorq import relations as xorq_rel, walk_nodes + except ImportError: + return left_tbl, right_tbl -def _extract_columns_from_callable( - fn: Any, - table: ir.Table, -) -> ColumnExtractionResult: - """Extract column names referenced by a callable. + # Find a canonical backend from the left tree. + try: + db_tables = list(walk_nodes((xorq_rel.DatabaseTable,), left_tbl)) + except Exception: + # Plain ibis tree — no DatabaseTable nodes for xorq to walk. + logger.debug("walk_nodes failed on plain ibis left side", exc_info=True) + return left_tbl, right_tbl + canonical = db_tables[0].source if db_tables else None - Uses immutable tracking and returns structured result. - """ - if not callable(fn): - return ColumnExtractionResult.success(frozenset()) + if canonical is None: + return left_tbl, right_tbl - tracker_ref = [ColumnTracker()] + def _recreate(op, _kwargs, **overrides): + kwargs = dict(zip(op.__argnames__, op.__args__, strict=False)) + if _kwargs: + kwargs.update(_kwargs) + kwargs.update(overrides) + return op.__recreate__(kwargs) - def on_column_access(col_name: str) -> None: - tracker_ref[0] = tracker_ref[0].with_column(col_name) + def replacer(op, _kwargs): + if isinstance(op, xorq_rel.DatabaseTable) and op.source is not canonical: + return _recreate(op, _kwargs, source=canonical) + # Propagate rewritten children (e.g. SelfReference wrapping + # a replaced DatabaseTable). + if _kwargs: + return _recreate(op, _kwargs) + return op - try: - tracking_proxy = _make_tracking_proxy(table, on_column_access) - fn(tracking_proxy) - return ColumnExtractionResult.success(tracker_ref[0].columns) + new_left = left_tbl.op().replace(replacer).to_expr() + new_right = right_tbl.op().replace(replacer).to_expr() + return new_left, new_right - except Exception as e: - return ColumnExtractionResult.failure(e) + def execute(self): + return _rebind_to_canonical_backend(self.to_untagged()).execute() + def compile(self, **kwargs): + return _rebind_to_canonical_backend(self.to_untagged()).compile(**kwargs) -def _extract_join_key_columns( - on: Callable[[Any, Any], ir.BooleanValue], - left_table: ir.Table, - right_table: ir.Table, -) -> JoinColumnExtractionResult: - left_tracker_ref = [ColumnTracker()] - right_tracker_ref = [ColumnTracker()] + def sql(self, **kwargs): + return ibis.to_sql(_rebind_to_canonical_backend(self.to_untagged()), **kwargs) - def on_left_access(col_name: str) -> None: - left_tracker_ref[0] = left_tracker_ref[0].with_column(col_name) + def __getitem__(self, key): + dims_dict = self.get_dimensions() + if key in dims_dict: + return dims_dict[key] - def on_right_access(col_name: str) -> None: - right_tracker_ref[0] = right_tracker_ref[0].with_column(col_name) + meas_dict = self.get_measures() + if key in meas_dict: + return meas_dict[key] - try: - left_proxy = _make_tracking_proxy(left_table, on_left_access) - right_proxy = _make_tracking_proxy(right_table, on_right_access) - on(left_proxy, right_proxy) + calc_meas_dict = self.get_calculated_measures() + if key in calc_meas_dict: + return calc_meas_dict[key] - return JoinColumnExtractionResult.success( - left_tracker_ref[0].columns, - right_tracker_ref[0].columns, + raise KeyError( + f"'{key}' not found in dimensions, measures, or calculated measures", ) - except Exception as e: - return JoinColumnExtractionResult.failure(e) - - -# ============================================================================== -# Table Column Requirements -# ============================================================================== - - -@frozen -class TableColumnRequirements: - """Immutable representation of column requirements per table. - - Maps table names to sets of required column names. - """ - - requirements: FrozenDict[str, frozenset[str]] = field( - factory=lambda: FrozenDict({}), - converter=lambda d: FrozenDict( - {k: frozenset(v) if not isinstance(v, frozenset) else v for k, v in d.items()}, - ), - ) - - def with_column(self, table_name: str, col_name: str) -> TableColumnRequirements: - """Return new requirements with additional column for table.""" - current_cols = self.requirements.get(table_name, frozenset()) - updated_cols = current_cols | {col_name} - - return TableColumnRequirements( - requirements=dict(self.requirements) | {table_name: updated_cols}, - ) + def pipe(self, func, *args, **kwargs): + return func(self, *args, **kwargs) - def with_columns( - self, - table_name: str, - col_names: Iterable[str], - ) -> TableColumnRequirements: - """Return new requirements with multiple columns for table.""" - current_cols = self.requirements.get(table_name, frozenset()) - updated_cols = current_cols | frozenset(col_names) - - return TableColumnRequirements( - requirements=dict(self.requirements) | {table_name: updated_cols}, + def as_table(self) -> SemanticTable: + """Convert to SemanticTable, preserving merged metadata from both sides.""" + return _semantic_table( + table=self.to_untagged(), + dimensions=self.get_dimensions(), + measures=self.get_measures(), + calc_measures=self.get_calculated_measures(), ) - def merge(self, other: TableColumnRequirements) -> TableColumnRequirements: - """Merge requirements from another instance.""" - merged_dict = dict(self.requirements) - - for table, cols in other.requirements.items(): - if table in merged_dict: - merged_dict[table] = merged_dict[table] | cols - else: - merged_dict[table] = cols - - return TableColumnRequirements(requirements=merged_dict) - - def to_dict(self) -> dict[str, set[str]]: - """Convert to mutable dict for API compatibility.""" - return {table: set(cols) for table, cols in self.requirements.items()} - - -def _parse_prefixed_field(field_name: str) -> tuple[str | None, str]: - """Parse potentially prefixed field name. - - Args: - field_name: Field name, possibly prefixed (e.g., "table.column") - - Returns: - Tuple of (table_name or None, column_name) - """ - if "." in field_name: - table, col = field_name.split(".", 1) - return (table, col) - return (None, field_name) - - -def _extract_requirements_from_keys( - keys: Iterable[str], - merged_dimensions: Mapping[str, Any], - all_roots: Sequence[Any], - table: ir.Table, -) -> TableColumnRequirements: - """Extract column requirements from group-by keys using graph traversal.""" - requirements = TableColumnRequirements() - - for key in keys: - table_name, col_name = _parse_prefixed_field(key) - - if table_name: - # Prefixed: we know the table - requirements = requirements.with_column(table_name, col_name) - else: - # Unprefixed: resolve dimension or use conservative fallback - if key in merged_dimensions: - dim_fn = merged_dimensions[key] - - try: - # Evaluate the dimension to get an Ibis expression - dim_expr = dim_fn(table) - - # Walk the expression graph to find all Field nodes (column references) - field_names = {node.name for node in walk_nodes(ibis_ops.Field, dim_expr)} - - # Filter to only actual columns in the table schema - actual_cols = {col for col in field_names if col in table.columns} - - if actual_cols: - for root in all_roots: - if root.name: - requirements = requirements.with_columns(root.name, actual_cols) - else: - # Fallback: assume key name is column name - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, key) - except Exception: - logger.debug( - "dimension graph traversal failed for %r; " - "treating key name as column name", - key, - exc_info=True, - ) - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, key) - else: - # Raw column - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, key) - - return requirements +from ._order_limit import SemanticLimitOp, SemanticOrderByOp # noqa: E402, F401 -def _extract_requirements_from_measures( - aggs: Mapping[str, Callable], - all_roots: Sequence[Any], - table: ir.Table, -) -> TableColumnRequirements: - """Extract column requirements from measure aggregations using graph traversal.""" - requirements = TableColumnRequirements() - for measure_name, measure_fn in aggs.items(): - fn = _unwrap(measure_fn) - - try: - # Evaluate the measure to get an Ibis expression - measure_expr = fn(table) - - # Walk the expression graph to find all Field nodes (column references) - field_names = {node.name for node in walk_nodes(ibis_ops.Field, measure_expr)} +from ._index_op import ( # noqa: E402 + SemanticIndexOp, + _build_numeric_index_fragment, + _build_string_index_fragment, + _get_field_type_str, + _get_fields_to_index, + _get_weight_expr, + _resolve_selector, +) - # Filter to only actual columns in the table schema - actual_cols = {col for col in field_names if col in table.columns} - if actual_cols: - for root in all_roots: - if root.name: - requirements = requirements.with_columns(root.name, actual_cols) - except Exception: - logger.debug( - "measure graph traversal failed for %r; " - "falling back to name-based column inference", - measure_name, - exc_info=True, - ) - # Conservative fallback: if measure name looks like a column, include it - if measure_name.isidentifier(): - for root in all_roots: - if root.name: - requirements = requirements.with_column(root.name, measure_name) +from ._root_models import ( # noqa: E402 + _build_column_rename_map, + _build_join_depth_map, + _check_and_add_rename, + _dimension_only_source_table, + _extract_join_key_column_names, + _find_all_root_models, + _find_root_model, + _merge_fields_with_prefixing, + _update_measure_refs_in_calc, + _wrap_dimension_for_renamed_column, +) - return requirements +# Column-extraction classes & helpers were moved to ._column_extraction. +# Re-imported here so internal callers in this module keep working. +from ._column_extraction import ( # noqa: E402 + ColumnExtractionResult, + ColumnTracker, + JoinColumnExtractionResult, + TableColumnRequirements, + _extract_columns_from_callable, + _extract_join_key_columns, + _extract_requirements_from_keys, + _extract_requirements_from_measures, + _make_tracking_proxy, + _parse_prefixed_field, +) diff --git a/src/boring_semantic_layer/ops/_format.py b/src/boring_semantic_layer/ops/_format.py new file mode 100644 index 0000000..5c450b5 --- /dev/null +++ b/src/boring_semantic_layer/ops/_format.py @@ -0,0 +1,160 @@ +"""Pretty-printing of semantic-layer operation chains. + +Used by every ``Semantic*Op.__repr__``. The Op classes themselves are +imported lazily inside each helper to avoid a circular module +dependency with ``_core``. +""" + +from __future__ import annotations + +from ibis.expr.operations.relations import Relation + + +def _collect_chain(op: Relation) -> list[Relation]: + """Walk op.source (or op.left for joins) back to root, return list from root to current.""" + chain = [op] + current = op + while True: + if hasattr(current, "source") and current.source is not None: + chain.append(current.source) + current = current.source + elif hasattr(current, "left") and current.left is not None: + chain.append(current.left) + current = current.left + else: + break + chain.reverse() + return chain + + +def _format_op_summary(op: Relation) -> str: + """Return a one-line summary string for a non-root semantic op.""" + from ._core import ( + SemanticAggregateOp, + SemanticFilterOp, + SemanticGroupByOp, + SemanticIndexOp, + SemanticJoinOp, + SemanticLimitOp, + SemanticMutateOp, + SemanticOrderByOp, + SemanticProjectOp, + SemanticTableOp, + SemanticUnnestOp, + ) + from ._root_models import _find_all_root_models + + cls = type(op).__name__ + + if isinstance(op, SemanticFilterOp): + predicate = object.__getattribute__(op, "predicate") + pred_name = "" + if hasattr(predicate, "__name__"): + pred_name = predicate.__name__ + elif hasattr(predicate, "unwrap"): + unwrapped = predicate.unwrap + if hasattr(unwrapped, "__name__"): + pred_name = unwrapped.__name__ + return f"Filter(λ {pred_name})" + + if isinstance(op, SemanticMutateOp): + post = object.__getattribute__(op, "post") + cols = list(post.keys()) + return f"Mutate({', '.join(cols)})" + + if isinstance(op, SemanticGroupByOp): + keys = object.__getattribute__(op, "keys") + return f"GroupBy({', '.join(keys)})" + + if isinstance(op, SemanticAggregateOp): + aggs = object.__getattribute__(op, "aggs") + agg_names = list(aggs.keys()) + return f"Aggregate({', '.join(agg_names)})" + + if isinstance(op, SemanticOrderByOp): + keys = object.__getattribute__(op, "keys") + key_strs = [k if isinstance(k, str) else repr(k) for k in keys] + return f"OrderBy({', '.join(key_strs)})" + + if isinstance(op, SemanticLimitOp): + return f"Limit({op.n})" + + if isinstance(op, SemanticProjectOp): + fields = object.__getattribute__(op, "fields") + return f"Project({', '.join(fields)})" + + if isinstance(op, SemanticUnnestOp): + column = object.__getattribute__(op, "column") + return f"Unnest({column})" + + if isinstance(op, SemanticJoinOp): + how = object.__getattribute__(op, "how") + right = object.__getattribute__(op, "right") + right_name = "" + if isinstance(right, SemanticTableOp): + right_name = object.__getattribute__(right, "name") or "" + if not right_name: + # Try to find a root name from right side + roots = _find_all_root_models(right) + if roots: + right_name = object.__getattribute__(roots[0], "name") or "" + if right_name: + return f"Join({how}, right={right_name})" + return f"Join({how})" + + if isinstance(op, SemanticIndexOp): + parts = [] + selector = object.__getattribute__(op, "selector") + by = object.__getattribute__(op, "by") + sample = object.__getattribute__(op, "sample") + if selector is not None: + if isinstance(selector, tuple): + parts.append(", ".join(selector)) + else: + parts.append(str(selector)) + if by is not None: + parts.append(f"by={by}") + if sample is not None: + parts.append(f"sample={sample}") + return f"Index({', '.join(parts)})" + + # Fallback for unknown op types + return cls.replace("Semantic", "").replace("Op", "") + + +def _format_root(root_op) -> str: + """Format a SemanticTableOp root using the fmt registry from format.py.""" + from boring_semantic_layer.format import fmt + + try: + return fmt(root_op) + except Exception: + # Fallback if format module isn't available + name = object.__getattribute__(root_op, "name") + return f"SemanticTable: {name}" if name else "SemanticTable" + + +def _semantic_repr(op: Relation) -> str: + from ._core import SemanticTableOp + + chain = _collect_chain(op) + + # Find the root (first element should be a SemanticTableOp) + root = chain[0] + if isinstance(root, SemanticTableOp): + lines = [_format_root(root)] + else: + # Fallback: no SemanticTableOp root found + from ibis.expr.format import pretty + + try: + return pretty(op) + except Exception: + return object.__repr__(op) + + # Append pipeline steps + for step in chain[1:]: + if not isinstance(step, SemanticTableOp): + lines.append(f"-> {_format_op_summary(step)}") + + return "\n".join(lines) diff --git a/src/boring_semantic_layer/ops/_index_op.py b/src/boring_semantic_layer/ops/_index_op.py new file mode 100644 index 0000000..5013104 --- /dev/null +++ b/src/boring_semantic_layer/ops/_index_op.py @@ -0,0 +1,321 @@ +"""``SemanticIndexOp`` and its index-fragment helpers. + +``index()`` builds a tall ``(fieldName, fieldPath, fieldType, fieldValue, +weight)`` table for each indexed dimension. String fields produce one row +per distinct value; numeric fields collapse to a single ``min .. max`` +row. The result is a ``UNION ALL`` of these per-field fragments. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import reduce +from typing import TYPE_CHECKING, Any + +import ibis +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation + +from .._xorq import FrozenOrderedDict, Schema, selectors as s +from ._format import _semantic_repr +from ._root_models import _find_all_root_models +from ._xorq_compat import _rebind_to_canonical_backend + +if TYPE_CHECKING: + from ..expr import SemanticFilter, SemanticLimit, SemanticOrderBy + + +def _get_field_type_str(field_type: Any) -> str: + return ( + "string" + if field_type.is_string() + else "number" + if field_type.is_numeric() + else "date" + if field_type.is_temporal() + else str(field_type) + ) + + +def _get_weight_expr( + base_tbl: Any, + by_measure: str | None, + all_roots: list, + is_string: bool, +) -> Any: + from .._xorq import api as xo + from ._core import _get_merged_fields + + if not by_measure: + return xo._.count() + + merged_measures = _get_merged_fields(all_roots, "measures") + return ( + merged_measures[by_measure](base_tbl) if by_measure in merged_measures else xo._.count() + ) + + +def _build_string_index_fragment( + base_tbl: Any, + field_expr: Any, + field_name: str, + field_path: str, + type_str: str, + weight_expr: Any, +) -> Any: + from .._xorq import api as xo + + return ( + base_tbl.group_by(field_expr.name("value")) + .aggregate(weight=weight_expr) + .select( + fieldName=xo.literal(field_name.split(".")[-1]), + fieldPath=xo.literal(field_path), + fieldType=xo.literal(type_str), + fieldValue=xo._["value"].cast("string"), + weight=xo._["weight"], + ) + ) + + +def _build_numeric_index_fragment( + base_tbl: Any, + field_expr: Any, + field_name: str, + field_path: str, + type_str: str, + weight_expr: Any, +) -> Any: + from .._xorq import api as xo + + return ( + base_tbl.select(field_expr.name("value")) + .filter(xo._["value"].notnull()) + .aggregate( + min_val=xo._["value"].min(), + max_val=xo._["value"].max(), + weight=weight_expr, + ) + .select( + fieldName=xo.literal(field_name.split(".")[-1]), + fieldPath=xo.literal(field_path), + fieldType=xo.literal(type_str), + fieldValue=( + xo._["min_val"].cast("string") + " to " + xo._["max_val"].cast("string") + ), + weight=xo._["weight"], + ) + ) + + +def _resolve_selector( + selector: str | list[str] | Callable | None, + base_tbl: ir.Table, +) -> tuple[str, ...]: + if selector is None: + return tuple(base_tbl.columns) + try: + selected = base_tbl.select(selector) + return tuple(selected.columns) + except Exception: + return [] + + +def _get_fields_to_index( + selector: str | list[str] | Callable | None, + merged_dimensions: dict, + base_tbl: ir.Table, +) -> tuple[str, ...]: + if selector is None: + selector = s.all() + + raw_fields = _resolve_selector(selector, base_tbl) + + if not raw_fields: + result = list(merged_dimensions.keys()) + result.extend(col for col in base_tbl.columns if col not in result) + else: + result = [col for col in raw_fields if col in merged_dimensions or col in base_tbl.columns] + + return result + + +class SemanticIndexOp(Relation): + source: Relation + selector: str | list[str] | tuple[str, ...] | Callable | None + by: str | None = None + sample: int | None = None + + def __init__( + self, + source: Relation, + selector: str | list[str] | tuple[str, ...] | Callable | None = None, + by: str | None = None, + sample: int | None = None, + ) -> None: + from ._core import _get_merged_fields + + # Validate sample parameter + if sample is not None and sample <= 0: + raise ValueError(f"sample must be positive, got {sample}") + + # Validate 'by' measure exists if provided + if by is not None: + all_roots = _find_all_root_models(source) + if all_roots: + merged_measures = _get_merged_fields(all_roots, "measures") + if by not in merged_measures: + available = list(merged_measures.keys()) + raise KeyError( + f"Unknown measure '{by}' for weight calculation. " + f"Available measures: {', '.join(available) or 'none'}", + ) + + # Convert selector to tuple if it's a list (Ibis requires hashable types) + hashable_selector = tuple(selector) if isinstance(selector, list) else selector + + super().__init__( + source=Relation.__coerce__(source), + selector=hashable_selector, + by=by, + sample=sample, + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + from .._xorq import api as xo + + return FrozenOrderedDict( + { + "fieldName": xo.literal("").op(), + "fieldPath": xo.literal("").op(), + "fieldType": xo.literal("").op(), + "fieldValue": xo.literal("").op(), + "weight": xo.literal(0).op(), + }, + ) + + @property + def schema(self) -> Schema: + return Schema( + { + "fieldName": "string", + "fieldPath": "string", + "fieldType": "string", + "fieldValue": "string", + "weight": "int64", + }, + ) + + @property + def keys(self) -> tuple[str, ...]: + return ("fieldValue", "fieldName", "fieldPath", "fieldType") + + @property + def aggs(self) -> dict[str, Any]: + return {"weight": lambda t: t.weight} + + def to_untagged(self): + from ._core import _get_merged_fields, _to_untagged + + all_roots = _find_all_root_models(self.source) + base_tbl = ( + _to_untagged(self.source).limit(self.sample) + if self.sample + else _to_untagged(self.source) + ) + + merged_dimensions = _get_merged_fields(all_roots, "dimensions") + fields_to_index = _get_fields_to_index( + self.selector, + merged_dimensions, + base_tbl, + ) + + if not fields_to_index: + from .._xorq import api as xo + + return xo.memtable( + { + "fieldName": [], + "fieldPath": [], + "fieldType": [], + "fieldValue": [], + "weight": [], + }, + ) + + def build_fragment(field_name: str) -> Any: + field_expr = ( + merged_dimensions[field_name](base_tbl) + if field_name in merged_dimensions + else base_tbl[field_name] + ) + field_type = field_expr.type() + type_str = _get_field_type_str(field_type) + weight_expr = _get_weight_expr( + base_tbl, + self.by, + all_roots, + field_type.is_string(), + ) + + return ( + _build_string_index_fragment( + base_tbl, + field_expr, + field_name, + field_name, + type_str, + weight_expr, + ) + if field_type.is_string() or not field_type.is_numeric() + else _build_numeric_index_fragment( + base_tbl, + field_expr, + field_name, + field_name, + type_str, + weight_expr, + ) + ) + + fragments = [build_fragment(f) for f in fields_to_index] + return reduce(lambda acc, frag: acc.union(frag), fragments[1:], fragments[0]) + + def filter(self, predicate: Callable) -> "SemanticFilter": + from ..expr import SemanticFilter + + return SemanticFilter(source=self, predicate=predicate) + + def order_by(self, *keys: str | ir.Value | Callable) -> "SemanticOrderBy": + from ..expr import SemanticOrderBy + + return SemanticOrderBy(source=self, keys=keys) + + def limit(self, n: int, offset: int = 0) -> "SemanticLimit": + from ..expr import SemanticLimit + + return SemanticLimit(source=self, n=n, offset=offset) + + def execute(self): + return _rebind_to_canonical_backend(self.to_untagged()).execute() + + def as_expr(self): + """Return self as expression.""" + return self + + def compile(self, **kwargs): + return self.to_untagged().compile(**kwargs) + + def sql(self, **kwargs): + return ibis.to_sql(self.to_untagged(), **kwargs) + + def __getitem__(self, key): + return self.to_untagged()[key] + + def pipe(self, func, *args, **kwargs): + return func(self, *args, **kwargs) diff --git a/src/boring_semantic_layer/ops/_measure_helpers.py b/src/boring_semantic_layer/ops/_measure_helpers.py new file mode 100644 index 0000000..f531ba1 --- /dev/null +++ b/src/boring_semantic_layer/ops/_measure_helpers.py @@ -0,0 +1,314 @@ +"""Measure-classification, JSON-definition, and column-error helpers. + +These analyze user-supplied measure expressions to decide whether they're +"base" (apply an aggregation to a column) or "calc" (reference other +measures), wrap them so the same pipeline can execute both, and produce +the human-readable diagnostic when a dimension lambda references a +column that doesn't exist on the table. +""" + +from __future__ import annotations + +import re +from collections.abc import Mapping +from difflib import get_close_matches +from typing import Any + +from ibis.expr import types as ir +from ibis.expr.operations.relations import Field +from returns.maybe import Maybe, Nothing, Some +from returns.result import Success, safe +from toolz import curry + +from ..measure_scope import ( + AggregationExpr, + AllOf, + BinOp, + ColumnScope, + MeasureRef, + MethodCall, +) +from ._callable import _CallableWrapper, _infer_unnest +from ._values import Measure, _is_deferred + + +def _extract_measure_metadata( + fn_or_expr: Any, +) -> tuple[Any, str | None, tuple, Mapping[str, Any]]: + """Extract metadata from various measure representations.""" + if isinstance(fn_or_expr, dict): + return ( + fn_or_expr["expr"], + fn_or_expr.get("description"), + tuple(fn_or_expr.get("requires_unnest", [])), + dict(fn_or_expr.get("metadata") or {}), + ) + elif isinstance(fn_or_expr, Measure): + return ( + fn_or_expr.expr, + fn_or_expr.description, + fn_or_expr.requires_unnest, + dict(fn_or_expr.metadata), + ) + else: + return (fn_or_expr, None, (), {}) + + +_AGG_METHODS = frozenset({"sum", "mean", "avg", "count", "min", "max"}) + + +def _is_calculated_measure(val: Any) -> bool: + # A MethodCall with an aggregation method on a MeasureRef is a base measure: + # the column name matched a known measure name in MeasureScope, but the user + # is really defining a column aggregation (e.g. lambda t: t.flight_count.sum()). + if ( + isinstance(val, MethodCall) + and val.method in _AGG_METHODS + and isinstance(val.receiver, MeasureRef) + ): + return False + return isinstance(val, MeasureRef | AllOf | BinOp | MethodCall | int | float) + + +def _matches_aggregation_pattern(measure_expr, agg_expr, tbl): + if not isinstance(agg_expr, AggregationExpr): + return Success(False) + + @curry + def evaluate_in_scope(tbl, expr): + """Evaluate measure expression in a ColumnScope.""" + scope = ColumnScope(_tbl=tbl) + return ( + expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr + ) + + @curry + def has_matching_operation(agg_expr, result): + """Check if the operation matches the expected aggregation. + + All our supported aggregations (Sum, Mean, Count, Min, Max) are ibis operations. + """ + op_name = type(result.op()).__name__.lower() + expected_op = "avg" if agg_expr.operation.lower() == "mean" else agg_expr.operation.lower() + + return expected_op in op_name + + @curry + def has_matching_column(agg_expr, result): + """Check if result's operation references the expected column. + + All supported aggregation operations (Sum, Mean, Count, Min, Max) have: + - args[0]: Field operation with .name attribute + - args[1]: Optional where clause (typically None) + """ + op = result.op() + + if not isinstance(op.args[0], Field): + return False + + return op.args[0].name == agg_expr.column + + def matches_pattern(result): + """Check if result matches both operation and column.""" + return has_matching_operation(agg_expr, result) and has_matching_column(agg_expr, result) + + return safe(lambda: evaluate_in_scope(tbl, measure_expr))().map(matches_pattern) + + +def _find_matching_measure(agg_expr, known_measures: dict, tbl): + """Find a measure that matches the aggregation expression pattern. + + Returns Maybe[str] using functional patterns. + """ + if not isinstance(agg_expr, AggregationExpr): + return Nothing + + @curry + def matches_pattern(agg_expr, tbl, measure_obj): + """Check if measure matches the aggregation pattern. + + All measure_obj values are Measure instances with an expr attribute. + """ + result = _matches_aggregation_pattern(measure_obj.expr, agg_expr, tbl) + return result.value_or(False) + + for measure_name, measure_obj in known_measures.items(): + if matches_pattern(agg_expr, tbl, measure_obj): + return Some(measure_name) + + return Nothing + + +def _make_base_measure( + expr: Any, + description: str | None, + requires_unnest: tuple, + metadata: Mapping[str, Any] | None = None, +) -> Measure: + """Create a base measure with proper callable wrapping using functional patterns.""" + + @curry + def apply_aggregation(operation: str, column): + """Apply aggregation operation to a column using functional dispatch.""" + operations = { + "sum": lambda c: c.sum(), + "mean": lambda c: c.mean(), + "avg": lambda c: c.mean(), + "count": lambda c: c.count(), + "min": lambda c: c.min(), + "max": lambda c: c.max(), + } + + return ( + Maybe.from_optional(operations.get(operation)) + .map(lambda fn: fn(column)) + .value_or( + (_ for _ in ()).throw(ValueError(f"Unknown aggregation operation: {operation}")) + ) + ) + + @curry + def evaluate_expr(expr, scope): + """Evaluate expression in given scope.""" + return ( + expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr + ) + + def convert_aggregation_expr(t, agg_expr: AggregationExpr): + """Convert AggregationExpr to ibis expression.""" + if agg_expr.operation == "count": + result = t.count() + else: + result = apply_aggregation(agg_expr.operation, t[agg_expr.column]) + + for method_name, args, kwargs_tuple in agg_expr.post_ops: + result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) + + return result + + raw_expr = expr._fn if isinstance(expr, _CallableWrapper) else expr + + if isinstance(expr, AggregationExpr): + + def wrapped_expr(t): + """Convert AggregationExpr to ibis expression.""" + return convert_aggregation_expr(t, expr) + + return Measure( + expr=wrapped_expr, + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) + + if callable(expr): + + def wrapped_expr(t): + """Wrapped expression that handles AggregationExpr conversion.""" + scope = ColumnScope(_tbl=t) + result = evaluate_expr(expr, scope) + + if isinstance(result, AggregationExpr): + return convert_aggregation_expr(t, result) + return result + + return Measure( + expr=wrapped_expr, + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) + else: + return Measure( + expr=lambda t, fn=expr: evaluate_expr(fn, ColumnScope(_tbl=t)), + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) + + +def _classify_measure( + fn_or_expr: Any, scope: Any, measure_name: str | None = None +) -> tuple[str, Any]: + """Classify measure as 'calc' or 'base' with appropriate handling.""" + from ..measure_scope import validate_calc_ast + from ._core import _resolve_expr + + expr, description, requires_unnest, metadata = _extract_measure_metadata(fn_or_expr) + + resolved = safe(lambda: _resolve_expr(expr, scope))().map( + lambda val: ("calc", val) if _is_calculated_measure(val) else None + ) + + if isinstance(resolved, Success) and resolved.unwrap() is not None: + kind, value = resolved.unwrap() + validate_calc_ast(value, measure_name) + return (kind, value) + + if not requires_unnest and callable(expr): + # All scopes (MeasureScope, ColumnScope) have tbl attribute + table = scope.tbl + inferred_unnest = _infer_unnest(expr, table) + requires_unnest = requires_unnest or inferred_unnest + + return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) + + +def _build_json_definition( + dims_dict: dict, + meas_dict: dict, + name: str | None = None, + description: str | None = None, +) -> dict: + result = { + "dimensions": {n: spec.to_json() for n, spec in dims_dict.items()}, + "measures": {n: spec.to_json() for n, spec in meas_dict.items()}, + "entity_dimensions": {n: spec.to_json() for n, spec in dims_dict.items() if spec.is_entity}, + "event_timestamp": { + n: spec.to_json() for n, spec in dims_dict.items() if spec.is_event_timestamp + }, + "time_dimensions": { + n: spec.to_json() for n, spec in dims_dict.items() if spec.is_time_dimension + }, + "name": name, + } + if description is not None: + result["description"] = description + return result + + +def _format_column_error(e: AttributeError, table: ir.Table) -> str: + """Format a helpful error message for missing column errors.""" + # Extract the column name from the error + match = re.search(r"has no attribute ['\"]([^'\"]+)['\"]", str(e)) + missing_col = match.group(1) if match else "unknown" + + # Get available columns + available_cols = list(table.columns) if hasattr(table, "columns") else [] + + # Build error message + parts = [f"Dimension expression references non-existent column '{missing_col}'."] + + if len(available_cols) > 20: + parts.append(f"Table has {len(available_cols)} columns. First 15: {available_cols[:15]}") + elif available_cols: + parts.append(f"Available columns: {available_cols}") + else: + parts.append(f"No columns available in {type(table).__name__} object") + + # Suggest similar column names + suggestions = get_close_matches(missing_col, available_cols, n=3, cutoff=0.6) + if suggestions: + parts[-1] += f". Did you mean: {suggestions}?" + + # Add helpful tip + example = suggestions[0] if suggestions else "column_name" + parts.append( + f"\n\nTip: Check that your dimension expression uses the correct column name. " + f"For example: lambda t: t.{example}" + ) + + return " ".join(parts) diff --git a/src/boring_semantic_layer/ops/_mutate_unnest.py b/src/boring_semantic_layer/ops/_mutate_unnest.py new file mode 100644 index 0000000..7f87576 --- /dev/null +++ b/src/boring_semantic_layer/ops/_mutate_unnest.py @@ -0,0 +1,161 @@ +"""``SemanticMutateOp`` and ``SemanticUnnestOp``. + +Mutate adds derived columns to an aggregated/projected table; Unnest +explodes an array column into one row per element. Both pass through +dimension/measure metadata to their source. +""" + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from typing import Any + +from ibis.expr import datatypes as dt +from ibis.expr.operations.relations import Relation +from ibis.expr.schema import Schema + +from .._xorq import FrozenDict, FrozenOrderedDict +from ..measure_scope import MeasureScope +from ._callable import _ensure_wrapped +from ._format import _semantic_repr +from ._values import Dimension, Measure + + +class SemanticMutateOp(Relation): + source: Relation + post: dict[ + str, + Callable, + ] # Transformed to FrozenDict[str, _CallableWrapper] in __init__ + nested_columns: tuple[ + str, + ..., + ] = () # Inherited from source if it has nested columns + + def __init__( + self, + source: Relation, + post: dict[str, Callable] | None, + nested_columns: tuple[str, ...] = (), + ) -> None: + frozen_post = FrozenDict( + {name: _ensure_wrapped(fn) for name, fn in (post or {}).items()}, + ) + source_nested = nested_columns if nested_columns else getattr(source, "nested_columns", ()) + + super().__init__( + source=Relation.__coerce__(source), + post=frozen_post, + nested_columns=source_nested, + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _resolve_expr, _to_untagged, _unwrap + + agg_tbl = _to_untagged(self.source) + + # Process mutations incrementally so each can reference previous ones + # This allows: .mutate(rank=..., is_other=lambda t: t["rank"] > 5) + current_tbl = agg_tbl + for name, fn_wrapped in self.post.items(): + proxy = MeasureScope(_tbl=current_tbl, _known=[], _post_agg=True) + resolved = _resolve_expr(_unwrap(fn_wrapped), proxy) + + new_col = resolved.name(name) + current_tbl = current_tbl.mutate([new_col]) + + return current_tbl + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() + + +class SemanticUnnestOp(Relation): + """Unnest an array column, expanding rows (like Malloy's nested data pattern).""" + + source: Relation + column: str + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def schema(self) -> Schema: + # After unnesting, the schema changes - the array column is replaced by its element schema + # For now, delegate to source schema (ideally we'd update it) + return self.source.schema + + @property + def values(self) -> FrozenDict: + return FrozenDict({}) + + def to_untagged(self): + """Convert to Ibis expression with functional struct unpacking. + + Uses pure helper functions to extract struct fields when unnesting + produces struct columns that need to be expanded. + """ + from ._core import _to_untagged + + def build_struct_fields(col_expr, col_type): + """Pure function: build dict of struct field selections.""" + return {name: col_expr[name] for name in col_type.names} + + def unpack_struct_if_needed(unnested_tbl, column_name): + """Conditionally unpack struct fields into top-level columns.""" + if column_name not in unnested_tbl.columns: + return unnested_tbl + + col_expr = unnested_tbl[column_name] + col_type = col_expr.type() + + # Only Struct types have fields to unpack + if isinstance(col_type, dt.Struct) and col_type.fields: + struct_fields = build_struct_fields(col_expr, col_type) + return unnested_tbl.select(unnested_tbl, **struct_fields) + + return unnested_tbl + + tbl = _to_untagged(self.source) + + if self.column not in tbl.columns: + raise ValueError(f"Column '{self.column}' not found in table") + + try: + unnested = tbl.unnest(self.column) + except Exception as e: + raise ValueError(f"Failed to unnest column '{self.column}': {e}") from e + + return unpack_struct_if_needed(unnested, self.column) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() diff --git a/src/boring_semantic_layer/ops/_normalize.py b/src/boring_semantic_layer/ops/_normalize.py new file mode 100644 index 0000000..9511bb4 --- /dev/null +++ b/src/boring_semantic_layer/ops/_normalize.py @@ -0,0 +1,105 @@ +"""Argument normalizers for join predicates and dimension/measure names. + +Used by the user-facing ``SemanticTable`` API to accept strings, simple +``Deferred`` expressions (``_.col``), and callable predicates uniformly. +""" + +from __future__ import annotations + +from ibis.common.deferred import Deferred + +from ._values import _is_deferred + + +def _normalize_to_name(arg: str | Deferred) -> str: + """Convert a string or simple ``_.name`` Deferred to a plain string name. + + Accepts a plain string (returned as-is) or a Deferred whose resolver is a + simple attribute access on the top-level ``_`` variable (e.g. ``_.origin``). + + Complex expressions like ``_.distance.sum()`` or ``_.a.b`` are rejected + with a ``TypeError``. + """ + if isinstance(arg, str): + return arg + + # Duck-type: works for both ibis and xorq Deferred objects + resolver = getattr(arg, "_resolver", None) + if resolver is None: + raise TypeError( + f"Expected a string name or Deferred expression (_.name), got {type(arg).__name__}" + ) + + obj = getattr(resolver, "obj", None) + + # Try attribute access first (_.name -> Attr resolver with .name) + name_wrapper = getattr(resolver, "name", None) + + # Fall back to getitem access (_["name"] -> Item resolver with .indexer) + if name_wrapper is None: + name_wrapper = getattr(resolver, "indexer", None) + + if name_wrapper is None or obj is None: + raise TypeError( + f"Only simple Deferred expressions like _.name or _['name'] are supported " + f"as positional arguments, got: {arg!r}" + ) + + # Reject chained access like _.a.b (obj would itself have an .obj attr) + if getattr(obj, "obj", None) is not None: + raise TypeError( + f"Only simple Deferred expressions like _.name or _['name'] are supported " + f"as positional arguments, got: {arg!r}" + ) + + # Attr.name / Item.indexer is a Just wrapper; unwrap via .value + raw_name = getattr(name_wrapper, "value", name_wrapper) + if not isinstance(raw_name, str): + raise TypeError(f"Could not extract string name from Deferred expression: {arg!r}") + + return raw_name + + +def _normalize_join_predicate(on): + """Normalize a join predicate to a two-argument callable. + + Accepts: + - ``str`` – equi-join on a column present in both sides + - ``Deferred`` (``_.col``) – same, after extracting the name + - ``list[str | Deferred]`` – compound equi-join on multiple columns + - ``callable`` (non-Deferred) – returned as-is (existing lambda API) + - ``None`` – returned as-is (for cross joins) + """ + if on is None: + return on + + if isinstance(on, str): + name = on + return lambda left, right: getattr(left, name) == getattr(right, name) + + if _is_deferred(on): + name = _normalize_to_name(on) + return lambda left, right: getattr(left, name) == getattr(right, name) + + if isinstance(on, (list, tuple)): + names = [_normalize_to_name(item) for item in on] + if len(names) == 1: + name = names[0] + return lambda left, right: getattr(left, name) == getattr(right, name) + + def _compound_predicate(left, right): + from functools import reduce + from operator import and_ + + preds = [getattr(left, n) == getattr(right, n) for n in names] + return reduce(and_, preds) + + return _compound_predicate + + if callable(on): + return on + + raise TypeError( + f"join `on` must be a string, Deferred (_.col), list of strings/Deferred, " + f"or a callable, got {type(on).__name__}" + ) diff --git a/src/boring_semantic_layer/ops/_order_limit.py b/src/boring_semantic_layer/ops/_order_limit.py new file mode 100644 index 0000000..8bda048 --- /dev/null +++ b/src/boring_semantic_layer/ops/_order_limit.py @@ -0,0 +1,118 @@ +"""``SemanticOrderByOp`` and ``SemanticLimitOp`` — terminal pass-through ops. + +Both are thin wrappers that delegate dimension/measure metadata to their +source and only modify ``to_untagged()`` output. They share the same +shape: pass-through ``schema``/``values``, override ``__repr__`` and +``to_untagged``, and forward ``get_*`` accessors to ``self.source``. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Mapping +from typing import Any + +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation +from ibis.expr.schema import Schema + +from .._xorq import FrozenOrderedDict +from ._callable import _CallableWrapper, _ensure_wrapped +from ._format import _semantic_repr +from ._values import Dimension, Measure + + +class SemanticOrderByOp(Relation): + source: Relation + keys: tuple[ + str | ir.Value | Callable, + ..., + ] # Transformed to tuple[str | _CallableWrapper, ...] in __init__ + + def __init__(self, source: Relation, keys: Iterable[str | ir.Value | Callable]) -> None: + def wrap_key(k): + return k if isinstance(k, str | _CallableWrapper) else _ensure_wrapped(k) + + super().__init__( + source=Relation.__coerce__(source), + keys=tuple(wrap_key(k) for k in keys), + ) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _resolve_expr, _to_untagged, _unwrap + + tbl = _to_untagged(self.source) + + def resolve_order_key(key): + if isinstance(key, str): + return tbl[key] if key in tbl.columns else getattr(tbl, key, key) + elif isinstance(key, _CallableWrapper): + unwrapped = _unwrap(key) + return _resolve_expr(unwrapped, tbl) + return key + + return tbl.order_by([resolve_order_key(key) for key in self.keys]) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() + + +class SemanticLimitOp(Relation): + source: Relation + n: int + offset: int + + def __init__(self, source: Relation, n: int, offset: int = 0) -> None: + if n <= 0: + raise ValueError(f"limit must be positive, got {n}") + if offset < 0: + raise ValueError(f"offset must be non-negative, got {offset}") + super().__init__(source=Relation.__coerce__(source), n=n, offset=offset) + + def __repr__(self) -> str: + return _semantic_repr(self) + + @property + def values(self) -> FrozenOrderedDict[str, Any]: + return self.source.values + + @property + def schema(self) -> Schema: + return self.source.schema + + def to_untagged(self): + from ._core import _to_untagged + + tbl = _to_untagged(self.source) + return tbl.limit(self.n) if self.offset == 0 else tbl.limit(self.n, offset=self.offset) + + def get_dimensions(self) -> Mapping[str, Dimension]: + """Get dictionary of dimensions from source.""" + return self.source.get_dimensions() + + def get_measures(self) -> Mapping[str, Measure]: + """Get dictionary of measures from source.""" + return self.source.get_measures() + + def get_calculated_measures(self) -> Mapping[str, Any]: + """Get dictionary of calculated measures from source.""" + return self.source.get_calculated_measures() diff --git a/src/boring_semantic_layer/ops/_root_models.py b/src/boring_semantic_layer/ops/_root_models.py new file mode 100644 index 0000000..f88644f --- /dev/null +++ b/src/boring_semantic_layer/ops/_root_models.py @@ -0,0 +1,483 @@ +"""Root-model traversal and join-aware field merging. + +These helpers walk a relation tree to find the underlying ``SemanticTableOp`` +roots (one per leaf table) and merge their dimensions/measures with proper +table prefixing. They also handle the ibis quirk of renaming non-key columns +on the right side of a join with ``_right`` / ``_right2`` / … suffixes. + +``SemanticTableOp`` and ``SemanticJoinOp`` are imported lazily inside each +function to avoid a circular module dependency with ``_core.py``. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from ibis.expr import types as ir +from ibis.expr.operations.relations import Relation + +from .._xorq import FrozenDict +from ._column_extraction import _extract_columns_from_callable, _extract_join_key_columns +from ._values import Dimension + + +def _find_root_model(node: Any): + """Find root SemanticTableOp in the operation tree.""" + from ._core import SemanticTableOp + + cur = node + while cur is not None: + if isinstance(cur, SemanticTableOp): + return cur + parent = getattr(cur, "source", None) + cur = parent + return None + + +def _find_all_root_models(node: Any) -> tuple[Any, ...]: + """Find all root SemanticTableOps in the operation tree (handles joins with multiple roots).""" + from ._core import SemanticTableOp + + if isinstance(node, SemanticTableOp): + return [node] + + roots = [] + + if hasattr(node, "left") and hasattr(node, "right"): + roots.extend(_find_all_root_models(node.left)) + roots.extend(_find_all_root_models(node.right)) + elif hasattr(node, "source") and node.source is not None: + roots.extend(_find_all_root_models(node.source)) + + return roots + + +def _dimension_only_source_table( + keys: tuple[str, ...], + all_roots: Sequence[Any], + filters: tuple, +): + """Check if a dimension-only query can be routed to a single source table. + + When all requested dimension keys share a single table prefix and that + prefix maps to a root model whose dimensions cover every key, we can + bypass the join and query the dimension table directly. This ensures + dimension members with no matching fact rows are still returned. + + *filters* are the ``_CallableWrapper`` predicates collected between the + aggregate and the underlying join. Filters whose column references all + belong to the target table are forwarded; if any filter references columns + outside the target table the shortcut is disabled. + + Returns ``(root_op, unprefixed_keys, applicable_filters)`` or ``None``. + """ + from ._core import _to_untagged, _unwrap + + if not keys: + return None + + prefixes: set[str] = set() + unprefixed: list[str] = [] + for key in keys: + if "." not in key: + return None # Non-prefixed key — can't determine source + prefix, name = key.split(".", 1) + prefixes.add(prefix) + unprefixed.append(name) + + if len(prefixes) != 1: + return None # Keys span multiple tables + + target_prefix = next(iter(prefixes)) + + for root in all_roots: + if root.name == target_prefix: + root_dims = root.get_dimensions() + if all(k in root_dims for k in unprefixed): + # Validate that every filter only touches columns present + # on the target dimension table. If any filter references + # columns from other tables we cannot use the shortcut. + if filters: + tbl = _to_untagged(root) + tbl_cols = frozenset(tbl.columns) | frozenset(root_dims) + for flt in filters: + fn = _unwrap(flt) if hasattr(flt, "unwrap") else flt + extraction = _extract_columns_from_callable(fn, tbl) + if extraction.extraction_failed: + return None # Can't determine — bail out + if not extraction.columns <= tbl_cols: + return None # References columns outside target + return root, unprefixed, filters + + return None + + +def _build_join_depth_map(node: Any) -> dict[str, int]: + """Map each leaf table name to its actual ibis rname depth. + + ``SemanticJoinOp.to_untagged`` calls ``_join_depth`` to determine the + rname suffix for each join level. ``_join_depth`` counts the number + of ``SemanticJoinOp`` ancestors on the *left* spine. The right child + at depth *d* gets ``rname = _rname_for_depth(d)``. + + For nested subtrees on the right side of a join, ibis applies the + inner subtree's rname independently. So ``aircraft_models`` at inner + depth 1 gets ``_right``, not ``_right3`` even if the outer depth is 3. + + This function mirrors ``_join_depth`` logic: walk down the left spine, + recording the right child's depth at each level. If the right child is + itself a join tree, recurse to get inner depths for its leaves. + """ + from ._core import SemanticJoinOp, SemanticTableOp + + depth_map: dict[str, int] = {} + + def _record_leaf(n, depth: int): + """Record a leaf table at the given depth.""" + if isinstance(n, SemanticTableOp): + name = n.name + if name and name not in depth_map: + depth_map[name] = depth + + def _walk_join_spine(n): + """Walk the left spine of a join tree, recording depths.""" + if not isinstance(n, SemanticJoinOp): + # Leftmost leaf: depth 0 (root, never renamed) + _record_leaf(n, 0) + return + + depth = SemanticJoinOp._join_depth(n) + # The right child is joined at this depth + right = n.right + if isinstance(right, SemanticJoinOp): + # Right is a subtree — its leaves get inner depths + inner_map = _build_join_depth_map(right) + for tname, idepth in inner_map.items(): + if tname not in depth_map: + if idepth == 0: + # Leftmost leaf of subtree sits at the outer depth + # (it receives the outer rname suffix if conflicting) + depth_map[tname] = depth + else: + # Inner leaves keep their inner depth (inner rname) + depth_map[tname] = idepth + else: + _record_leaf(right, depth) + + # Recurse down the left spine + _walk_join_spine(n.left) + + _walk_join_spine(node) + return depth_map + + +def _update_measure_refs_in_calc(expr, prefix_map: dict[str, str]): + """ + Recursively update MeasureRef names in a calculated measure expression. + + Args: + expr: A MeasureExpr (MeasureRef, AllOf, BinOp, MethodCall, or literal) + prefix_map: Mapping from old name to new prefixed name + + Returns: + Updated expression with prefixed MeasureRef names + """ + from ..measure_scope import AllOf, BinOp, MeasureRef, MethodCall + + if isinstance(expr, MeasureRef): + # Update the measure reference name if it's in the map + new_name = prefix_map.get(expr.name, expr.name) + return MeasureRef(new_name) + elif isinstance(expr, AllOf): + # Update the inner MeasureRef + updated_ref = _update_measure_refs_in_calc(expr.ref, prefix_map) + return AllOf(updated_ref) + elif isinstance(expr, MethodCall): + updated_receiver = _update_measure_refs_in_calc(expr.receiver, prefix_map) + return MethodCall( + receiver=updated_receiver, + method=expr.method, + args=expr.args, + kwargs=expr.kwargs, + ) + elif isinstance(expr, BinOp): + # Recursively update left and right + updated_left = _update_measure_refs_in_calc(expr.left, prefix_map) + updated_right = _update_measure_refs_in_calc(expr.right, prefix_map) + return BinOp(op=expr.op, left=updated_left, right=updated_right) + else: + # Literal number or other - return as-is + return expr + + +def _extract_join_key_column_names(source: Relation) -> set[str]: + """ + Extract column names that ibis will merge (coalesce) during joins. + + Ibis only merges join-key columns when **both** sides of an equi-join share + the **same** column name (e.g., ``l.code == r.code``). When names differ + (e.g., ``l.carrier == r.code``), the right column gets a ``_right`` suffix + instead. We return only the intersection of left/right key names so that + ``_check_and_add_rename`` correctly detects columns that need renaming. + + Args: + source: The relation to search for join operations + + Returns: + Set of column names that ibis merges (same-name equi-join keys) + """ + from ._core import SemanticJoinOp + + join_keys: set[str] = set() + + def find_joins(node): + """Recursively find join operations and extract merged key columns.""" + if isinstance(node, SemanticJoinOp) and node.on: + try: + left_expr = node.left.to_expr() if hasattr(node.left, "to_expr") else node.left + right_expr = node.right.to_expr() if hasattr(node.right, "to_expr") else node.right + result = _extract_join_key_columns(node.on, left_expr, right_expr) + if result.is_success(): + # ibis merges only same-name equi-join columns + join_keys.update(result.left_columns & result.right_columns) + except (AttributeError, TypeError): + pass + + if hasattr(node, "left") and isinstance(node.left, Relation): + find_joins(node.left) + if hasattr(node, "right") and isinstance(node.right, Relation): + find_joins(node.right) + if hasattr(node, "source") and isinstance(node.source, Relation): + find_joins(node.source) + + find_joins(source) + return join_keys + + +def _build_column_rename_map( + all_roots: Sequence[Any], + field_accessor: callable, + source: Relation | None = None, +) -> dict[str, str]: + """ + Build a mapping of dimension names to their renamed column names in joined tables. + + When Ibis joins tables with duplicate column names, it renames columns from later + tables with '_right' suffix. However, columns used as join keys are merged and + NOT renamed, so we exclude them from the rename map. + + Uses graph_utils for generic traversal and the returns library for safe handling. + + Args: + all_roots: List of root semantic tables in join order + field_accessor: Function to get fields (dimensions) from a root + source: Optional source relation to extract join keys from + + Returns: + Dict mapping dimension names like 'airports.city' to renamed columns like 'city_right' + """ + # Build column index using graph_utils (returns Result) + from returns.result import Failure + + from ..graph_utils import build_column_index_from_roots, extract_column_from_dimension + + column_index_result = build_column_index_from_roots(all_roots) + if isinstance(column_index_result, Failure): + # If we can't build the index, return empty map (dimensions will use fallback behavior) + return {} + + column_index = column_index_result.value_or({}) + + # Extract join key columns to exclude from renaming + join_keys = _extract_join_key_column_names(source) if source else set() + + # Build a map from table name → actual ibis join depth by walking the + # join tree. The flat index in all_roots does NOT equal ibis join depth + # for nested joins (e.g. aircraft → aircraft_models inside a flights + # join tree), so we must compute it from the tree structure. + join_depth_map: dict[str, int] = {} + if source is not None: + join_depth_map = _build_join_depth_map(source) + + # Process dimensions and determine which need renamed columns + rename_map = {} + + for idx, root in enumerate(all_roots): + if not root.name: + continue + + fields_dict = field_accessor(root) + if not fields_dict: + continue + + root_tbl = root.to_untagged() + # Use the actual join depth if available, otherwise fall back to table_idx + effective_depth = join_depth_map.get(root.name, idx) + + for field_name, field_value in fields_dict.items(): + # Extract column name using graph_utils (returns Maybe) + column_maybe = extract_column_from_dimension(field_value, root_tbl) + + # Use Maybe pattern from returns library + column_maybe.bind_optional( + lambda base_column: _check_and_add_rename( # noqa: B023 + rename_map=rename_map, + base_column=base_column, + prefixed_name=f"{root.name}.{field_name}", # noqa: B023 + table_idx=idx, # noqa: B023 + column_index=column_index, + join_keys=join_keys, + join_depth=effective_depth, # noqa: B023 + ) + ) + + return rename_map + + +def _check_and_add_rename( + rename_map: dict[str, str], + base_column: str, + prefixed_name: str, + table_idx: int, + column_index: dict[str, list[int]], + join_keys: set[str], + join_depth: int | None = None, +) -> None: + """ + Check if a column needs renaming and add to rename map if so. + + ``table_idx`` is the flat index in ``all_roots`` used to detect + whether an earlier table has the same column. ``join_depth`` is + the actual ibis join depth (from ``_build_join_depth_map``) used + to compute the ``_right`` / ``_right2`` / … suffix. + + Args: + rename_map: Map to update with renames + base_column: The base column name + prefixed_name: The prefixed dimension name (e.g., 'airports.city') + table_idx: Flat index in all_roots (for conflict detection) + column_index: Index of column occurrences + join_keys: Set of column names used as join keys (these don't get renamed) + join_depth: Actual ibis join depth for suffix computation (defaults to table_idx) + """ + # Skip columns that are join keys - they get merged, not renamed + if base_column in join_keys: + return + + depth = join_depth if join_depth is not None else table_idx + + if base_column in column_index: + tables_with_column = column_index[base_column] + # Check if any table before this one (in flat order) has the same column + earlier_tables = [t for t in tables_with_column if t < table_idx] + if earlier_tables: + suffix = "_right" if depth <= 1 else f"_right{depth}" + rename_map[prefixed_name] = f"{base_column}{suffix}" + + +def _wrap_dimension_for_renamed_column(dimension: Dimension, renamed_column: str) -> Dimension: + """ + Wrap a dimension to access a renamed column in a joined table. + + Args: + dimension: The original dimension + renamed_column: The renamed column name (e.g., 'city_right') + + Returns: + A new Dimension that accesses the renamed column + """ + + # Create a new callable that accesses the renamed column + def renamed_accessor(table: ir.Table) -> ir.Value: + return table[renamed_column] + + # Return a new Dimension with the wrapped callable but same metadata + return Dimension( + expr=renamed_accessor, + description=dimension.description, + is_entity=dimension.is_entity, + is_time_dimension=dimension.is_time_dimension, + is_event_timestamp=dimension.is_event_timestamp, + smallest_time_grain=dimension.smallest_time_grain, + derived_dimensions=dimension.derived_dimensions, + ) + + +def _merge_fields_with_prefixing( + all_roots: Sequence[Any], + field_accessor: callable, + source: Relation | None = None, +) -> FrozenDict[str, Any]: + """ + Generic function to merge any type of fields (dimensions, measures) with prefixing. + + Args: + all_roots: List of SemanticTable root models + field_accessor: Function that takes a root and returns the fields dict (e.g. lambda r: r.dimensions) + source: Optional source relation to extract join keys from for proper column renaming + + Returns: + FrozenDict mapping field names (always prefixed with table name) to field values + """ + if not all_roots: + return FrozenDict() + + merged_fields = {} + + is_calc_measures = False + is_dimensions = False + if all_roots: + sample_fields = field_accessor(all_roots[0]) + if sample_fields: + from ..measure_scope import AllOf, BinOp, MeasureRef, MethodCall + + first_val = next(iter(sample_fields.values()), None) + is_calc_measures = isinstance( + first_val, + MeasureRef | AllOf | BinOp | MethodCall | int | float, + ) + is_dimensions = isinstance(first_val, Dimension) + + # For dimensions, build a column rename map to handle Ibis join conflicts + column_rename_map = {} + if is_dimensions: + column_rename_map = _build_column_rename_map(all_roots, field_accessor, source) + + for root in all_roots: + root_name = root.name + fields_dict = field_accessor(root) + + if is_calc_measures and root_name: + base_map = ( + {k: f"{root_name}.{k}" for k in root.get_measures()} + if hasattr(root, "get_measures") + else {} + ) + calc_map = ( + {k: f"{root_name}.{k}" for k in root.get_calculated_measures()} + if hasattr(root, "get_calculated_measures") + else {} + ) + prefix_map = {**base_map, **calc_map} + + for field_name, field_value in fields_dict.items(): + if root_name: + # Always use prefixed name with . separator + prefixed_name = f"{root_name}.{field_name}" + + # If it's a calculated measure, update internal MeasureRefs + if is_calc_measures: + field_value = _update_measure_refs_in_calc(field_value, prefix_map) + # If it's a dimension that needs column renaming, wrap the callable + elif is_dimensions and prefixed_name in column_rename_map: + field_value = _wrap_dimension_for_renamed_column( + field_value, column_rename_map[prefixed_name] + ) + + merged_fields[prefixed_name] = field_value + else: + # Fallback to original name if no root name + merged_fields[field_name] = field_value + + return FrozenDict(merged_fields) diff --git a/src/boring_semantic_layer/ops/_values.py b/src/boring_semantic_layer/ops/_values.py new file mode 100644 index 0000000..74fcb12 --- /dev/null +++ b/src/boring_semantic_layer/ops/_values.py @@ -0,0 +1,166 @@ +"""Public value objects: ``Dimension`` and ``Measure``. + +Plus the prefix proxies that let dimension lambdas use model-prefix +navigation (``lambda t: t.flights.carrier``) on joined tables. +""" + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from typing import Any + +from attrs import field, frozen +from ibis.common.deferred import Deferred +from ibis.expr import types as ir + + +def _is_deferred(expr) -> bool: + # Duck-type check: works for both ibis and xorq Deferred objects + return hasattr(expr, "_resolver") and hasattr(expr, "resolve") + + +class _DimPrefixProxy: + """Resolves ``proxy.column`` to ``dims["prefix.column"](table)``.""" + + __slots__ = ("_tbl", "_dims", "_prefix") + + def __init__(self, tbl, dims: dict, prefix: str): + object.__setattr__(self, "_tbl", tbl) + object.__setattr__(self, "_dims", dims) + object.__setattr__(self, "_prefix", prefix) + + def __getattr__(self, name: str): + full_name = f"{self._prefix}.{name}" + if full_name in self._dims: + return self._dims[full_name](self._tbl) + raise AttributeError( + f"No dimension '{full_name}' found. " + f"Available dimensions with prefix '{self._prefix}.': " + f"{[k for k in self._dims if k.startswith(self._prefix + '.')]}" + ) + + +class _DimensionTableProxy: + """Proxy that wraps an ibis table to support model-prefix navigation. + + Allows dimension lambdas like ``lambda t: t.flights.carrier`` to work on + joined tables by resolving ``t.flights.carrier`` through the merged + dimension map (``dims["flights.carrier"](table)``). + """ + + __slots__ = ("_tbl", "_dims") + + def __init__(self, tbl, dims: dict): + object.__setattr__(self, "_tbl", tbl) + object.__setattr__(self, "_dims", dims) + + def __getattr__(self, name: str): + prefix = f"{name}." + if any(k.startswith(prefix) for k in self._dims): + return _DimPrefixProxy(self._tbl, self._dims, name) + return getattr(self._tbl, name) + + def __getitem__(self, name: str): + if name in self._dims: + return self._dims[name](self._tbl) + return self._tbl[name] + + @property + def columns(self): + return self._tbl.columns + + +@frozen(kw_only=True, slots=True) +class Dimension: + expr: Callable[[ir.Table], ir.Value] | Deferred + description: str | None = None + is_entity: bool = False + is_time_dimension: bool = False + is_event_timestamp: bool = False + smallest_time_grain: str | None = None + derived_dimensions: tuple[str, ...] = () + metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) + + def __call__(self, table: ir.Table, _dims: dict | None = None) -> ir.Value: + try: + return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) + except AttributeError as e: + # Retry with a prefix-aware proxy for joined tables where + # model prefixes are used (e.g., lambda t: t.flights.carrier) + if _dims and not _is_deferred(self.expr) and callable(self.expr): + try: + proxy = _DimensionTableProxy(table, _dims) + return self.expr(proxy) + except AttributeError as proxy_err: + # Preserve explicit prefix-proxy errors (e.g. missing + # "model.field") to avoid silent fallback to unprefixed + # columns, but keep normal missing-column errors on the + # original table so they get the helpful formatter below. + if str(proxy_err).startswith("No dimension '"): + raise + except Exception: + pass + # Provide helpful error for missing columns + if "'Table' object has no attribute" in str( + e + ) or "'Join' object has no attribute" in str(e): + from ._measure_helpers import _format_column_error + + raise AttributeError(_format_column_error(e, table)) from e + raise + + def to_json(self) -> Mapping[str, Any]: + base = {"description": self.description} + if self.is_entity: + base["is_entity"] = True + if self.is_event_timestamp: + base["is_event_timestamp"] = True + if self.is_time_dimension: + base["smallest_time_grain"] = self.smallest_time_grain + if self.derived_dimensions: + base["derived_dimensions"] = list(self.derived_dimensions) + if self.metadata: + base.update(self.metadata) + return base + + def __hash__(self) -> int: + return hash( + ( + self.description, + self.is_entity, + self.is_event_timestamp, + self.is_time_dimension, + self.smallest_time_grain, + self.derived_dimensions, + ), + ) + + +@frozen(kw_only=True, slots=True) +class Measure: + expr: Callable[[ir.Table], ir.Value] | Deferred + description: str | None = None + requires_unnest: tuple[str, ...] = () # Internal: Arrays that must be unnested + original_expr: Any = field(default=None, eq=False, hash=False) + metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) + + def __call__(self, table: ir.Table) -> ir.Value: + return self.expr.resolve(table) if _is_deferred(self.expr) else self.expr(table) + + @property + def locality(self) -> str | None: + """Derive locality from requires_unnest (most nested level).""" + return self.requires_unnest[-1] if self.requires_unnest else None + + def to_json(self) -> Mapping[str, Any]: + base = {"description": self.description} + if self.locality: + base["locality"] = self.locality + if self.requires_unnest: + base["requires_unnest"] = list(self.requires_unnest) + if self.metadata: + base.update(self.metadata) + return base + + def __hash__(self) -> int: + return hash((self.description, self.requires_unnest)) diff --git a/src/boring_semantic_layer/ops/_xorq_compat.py b/src/boring_semantic_layer/ops/_xorq_compat.py new file mode 100644 index 0000000..f4e9fbf --- /dev/null +++ b/src/boring_semantic_layer/ops/_xorq_compat.py @@ -0,0 +1,127 @@ +"""xorq/ibis compatibility shims. + +These functions bridge plain ibis and xorq's vendored ibis: convert one +to the other (``_ensure_xorq_table``), patch SortKey shape differences +(``_patch_xorq_sortkey_compat``), and rebind ``DatabaseTable`` nodes so +expressions composed from separately-converted tables share a single +backend (``_rebind_to_backend`` / ``_rebind_to_canonical_backend``). +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + + +def _patch_xorq_sortkey_compat(): + """Register a map_ibis handler so ibis SortKey → xorq SortKey. + + ibis 11 uses ``SortKey.expr``, ibis 12 renamed it to ``SortKey.arg``, + while xorq's vendored ibis keeps ``SortKey.expr``. Handle both. + """ + from ibis.expr.operations.sortkeys import SortKey as IbisSortKey + + from .._xorq import SortKey as XorqSortKey, map_ibis + + if IbisSortKey in map_ibis.registry: + return # already patched + + @map_ibis.register(IbisSortKey) + def _map_sort_key(val, kwargs=None): + # ibis 12 uses .arg, ibis 11 uses .expr + sort_expr = getattr(val, "arg", None) or getattr(val, "expr") + return XorqSortKey( + expr=map_ibis(sort_expr, None), + ascending=val.ascending, + nulls_first=val.nulls_first, + ) + + +def _ensure_xorq_table(table): + """Convert plain ibis Table to xorq-vendored ibis if possible. + + This is the single boundary between user-supplied ibis tables and + BSL's internal xorq representation. ``SemanticModel`` calls it once + at construction so internal code paths can assume xorq tables when + the backend is supported, and a plain ibis fallback otherwise. + + Falls back to returning the plain ibis table when the backend is not + supported by xorq (e.g. Databricks). Idempotent: calling it on a + xorq-vendored table is a cheap no-op. + """ + _patch_xorq_sortkey_compat() + if "xorq.vendor.ibis" not in type(table).__module__: + try: + from .._xorq import from_ibis + + return from_ibis(table) + except Exception: + # Backend isn't supported by xorq's map_ibis registry (e.g. + # Databricks). Fall back so plain-ibis paths can still execute. + logger.debug( + "from_ibis failed for %s; using plain ibis table", + type(table).__module__, + exc_info=True, + ) + return table + return table + + +def _rebind_to_backend(expr, target_backend): + """Rebind every ``DatabaseTable`` op in *expr* to *target_backend*. + + Low-level primitive shared with ``serialization.reconstruct``. + No-op on plain ibis expressions or when xorq is unavailable for any + reason; callers must pass a xorq-vendored ``target_backend``. + """ + try: + from .._xorq import relations as xorq_rel + except ImportError: + return expr + + def _recreate(op, _kwargs, **overrides): + kwargs = dict(zip(op.__argnames__, op.__args__, strict=False)) + if _kwargs: + kwargs.update(_kwargs) + kwargs.update(overrides) + return op.__recreate__(kwargs) + + def replacer(op, _kwargs): + if isinstance(op, xorq_rel.DatabaseTable) and op.source is not target_backend: + return _recreate(op, _kwargs, source=target_backend) + if _kwargs: + return _recreate(op, _kwargs) + return op + + return expr.op().replace(replacer).to_expr() + + +def _rebind_to_canonical_backend(expr): + """Rebind divergent ``DatabaseTable`` backends in *expr* to share one. + + ``from_ibis()`` creates a distinct ``Backend`` per call, so expressions + built by composing separately-converted tables contain multiple + backends. Picking the first ``DatabaseTable``'s source as canonical + and rebinding the rest eliminates "Multiple backends found" errors + at execution time. + + No-op on plain ibis expressions (not xorq-vendored). + """ + try: + from .._xorq import relations as xorq_rel, walk_nodes + except ImportError: + return expr + + try: + db_tables = list(walk_nodes((xorq_rel.DatabaseTable,), expr)) + except Exception: + # walk_nodes can't traverse plain ibis trees; treat as no-op. + logger.debug("walk_nodes failed on plain ibis expr", exc_info=True) + return expr + + canonical = db_tables[0].source if db_tables else None + if canonical is None: + return expr + + return _rebind_to_backend(expr, canonical)