Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/boring_semantic_layer/serialization/tag_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,37 @@ def from_tag_node(tag_node):
return reconstruct_bsl_operation(metadata, expr, ctx)


bsl_tag_handler = TagHandler(
def reemit(tag_node, rebuild_subexpr):
"""Re-emit a BSL-tagged subtree with a translated source.

``from_tag_node`` returns the base SemanticModel (discarding the query
chain), so it cannot be used for rebuild — rebuild needs the full tag
metadata to reproduce the original query. This function works from the
tag node directly: it rebuilds the source subtree and re-stamps the
original tag metadata on top.
"""
if tag_node.parent is None:
raise ValueError("tag_node has no parent; cannot rebuild a root tag node")
new_source = rebuild_subexpr(tag_node.parent.to_expr())
meta = dict(tag_node.metadata)
tag_name = meta.pop("tag")
return new_source.tag(tag=tag_name, **meta)


_handler_kwargs = dict(
tag_names=("bsl",),
extract_metadata=extract_metadata,
from_tag_node=from_tag_node,
)
if "reemit" in {a.name for a in TagHandler.__attrs_attrs__}:
_handler_kwargs["reemit"] = reemit

bsl_tag_handler = TagHandler(**_handler_kwargs)


__all__ = [
"bsl_tag_handler",
"extract_metadata",
"from_tag_node",
"reemit",
]
313 changes: 313 additions & 0 deletions src/boring_semantic_layer/tests/test_xorq_rebuild.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
"""Tests for BSL rebuild support via ``reemit`` on the TagHandler.

Covers:
- ``reemit`` is registered on ``bsl_tag_handler``
- Identity reemit preserves tag metadata (round-trip invariant)
- ``get_rebuild_dispatch`` returns handler-level reemit for BSL tags
- Catalog rebuild round-trip with BSL entries
- Rebuilt BSL entries execute correctly
- Query-chain (aggregate) rebuild
"""

from __future__ import annotations

from pathlib import Path

import ibis
import pytest

from boring_semantic_layer import SemanticModel
from boring_semantic_layer.serialization import to_tagged
from boring_semantic_layer.serialization.tag_handler import (
bsl_tag_handler,
reemit,
)

xorq = pytest.importorskip("xorq", reason="xorq not installed")

from xorq.expr.builders import TagHandler as _TagHandler

_has_reemit = "reemit" in {a.name for a in _TagHandler.__attrs_attrs__}
requires_reemit = pytest.mark.skipif(
not _has_reemit, reason="xorq TagHandler does not have reemit field"
)


def _tag_node(tagged_expr):
return tagged_expr.op()


# ---------------------------------------------------------------------------
# Phase 2: reemit registration
# ---------------------------------------------------------------------------


@requires_reemit
def test_reemit_registered_on_handler():
assert bsl_tag_handler.reemit is reemit


@requires_reemit
def test_reemit_is_callable():
assert callable(bsl_tag_handler.reemit)


# ---------------------------------------------------------------------------
# Phase 3: identity reemit preserves tag metadata (Invariant B)
# ---------------------------------------------------------------------------


@pytest.fixture
def simple_model():
table = ibis.memtable({"a": [1, 2, 3], "b": [4, 5, 6]})
return SemanticModel(
table=table,
dimensions={"a": lambda t: t.a, "b": lambda t: t.b},
measures={"sum_b": lambda t: t.b.sum(), "avg_b": lambda t: t.b.mean()},
name="simple",
)


@requires_reemit
def test_identity_reemit_preserves_metadata(simple_model):
tagged = to_tagged(simple_model)
original_meta = dict(_tag_node(tagged).metadata)

rebuilt = reemit(_tag_node(tagged), rebuild_subexpr=lambda e: e)
rebuilt_meta = dict(_tag_node(rebuilt).metadata)

assert original_meta == rebuilt_meta


@requires_reemit
def test_identity_reemit_on_query_chain(simple_model):
query = simple_model.query(dimensions=("a",), measures=("sum_b",))
tagged = to_tagged(query)
original_meta = dict(_tag_node(tagged).metadata)

rebuilt = reemit(_tag_node(tagged), rebuild_subexpr=lambda e: e)
rebuilt_meta = dict(_tag_node(rebuilt).metadata)

assert original_meta == rebuilt_meta


@requires_reemit
def test_reemit_with_source_transform(simple_model):
tagged = to_tagged(simple_model)
original_meta = dict(_tag_node(tagged).metadata)

def rename_column(expr):
return expr.rename(a_renamed="a")

rebuilt = reemit(_tag_node(tagged), rebuild_subexpr=rename_column)
rebuilt_meta = dict(_tag_node(rebuilt).metadata)
assert original_meta == rebuilt_meta
assert "a_renamed" in rebuilt.columns


@requires_reemit
def test_reemit_query_chain_with_source_transform(simple_model):
query = simple_model.query(dimensions=("a",), measures=("sum_b",))
tagged = to_tagged(query)
original_meta = dict(_tag_node(tagged).metadata)

def add_column(expr):
return expr.mutate(extra=ibis.literal(1))

rebuilt = reemit(_tag_node(tagged), rebuild_subexpr=add_column)
rebuilt_meta = dict(_tag_node(rebuilt).metadata)
assert original_meta == rebuilt_meta
assert "extra" in rebuilt.columns


# ---------------------------------------------------------------------------
# get_rebuild_dispatch returns handler-level reemit for BSL
# ---------------------------------------------------------------------------


@requires_reemit
def test_get_rebuild_dispatch_returns_callable_for_bsl(simple_model):
from xorq.expr.builders import get_rebuild_dispatch

tagged = to_tagged(simple_model)
dispatch = get_rebuild_dispatch(_tag_node(tagged))
assert callable(dispatch)


@requires_reemit
def test_get_rebuild_dispatch_invokes_handler_reemit(simple_model):
from xorq.expr.builders import get_rebuild_dispatch

tagged = to_tagged(simple_model)
dispatch = get_rebuild_dispatch(_tag_node(tagged))
result = dispatch(lambda e: e)
assert result is not None
rebuilt_meta = dict(_tag_node(result).metadata)
original_meta = dict(_tag_node(tagged).metadata)
assert original_meta == rebuilt_meta


# ---------------------------------------------------------------------------
# Catalog helpers
# ---------------------------------------------------------------------------


def _make_catalog(tmpdir, name="src"):
import xorq.api as xo
from xorq.catalog.backend import GitBackend
from xorq.catalog.catalog import Catalog

repo = Catalog.init_repo_path(Path(tmpdir).joinpath(name))
catalog = Catalog(backend=GitBackend(repo=repo))
return catalog, xo


def _add_source_with_identity_transform(catalog, xo, data, *, source_alias, transform_alias):
from xorq.vendor.ibis.expr import operations as ops

source_expr = xo.memtable(data, name=source_alias)
source_entry = catalog.add(source_expr, aliases=(source_alias,))

unbound = ops.UnboundTable(name="p", schema=source_expr.schema()).to_expr()
identity = unbound.select(*source_expr.columns)
transform_entry = catalog.add(identity, aliases=(transform_alias,))

return source_entry, transform_entry


def _replay_rebuild(source_catalog_obj, target_path):
from xorq.catalog.catalog import Catalog
from xorq.catalog.replay import Replayer

target = Catalog.from_repo_path(target_path, init=True)
Replayer(from_catalog=source_catalog_obj, rebuild=True).replay(target)
return target


# ---------------------------------------------------------------------------
# Catalog rebuild: query chain (SemanticAggregateOp)
# ---------------------------------------------------------------------------


@pytest.fixture
def catalog_with_bsl_query(tmpdir):
from xorq.catalog.bind import bind

catalog, xo = _make_catalog(tmpdir)

source_entry, transform_entry = _add_source_with_identity_transform(
catalog,
xo,
{"origin": ["JFK", "LAX", "ORD"], "delay": [10.0, -5.0, 3.0]},
source_alias="flights",
transform_alias="flights-identity",
)

bound = bind(source_entry, transform_entry)
model = SemanticModel(
table=bound,
dimensions={"origin": lambda t: t.origin},
measures={"avg_delay": lambda t: t.delay.mean()},
name="flights_model",
)
tagged = to_tagged(
model.query(dimensions=("origin",), measures=("avg_delay",))
)
bsl_entry = catalog.add(tagged, aliases=("origin-delays",))

return catalog, source_entry, bsl_entry


@requires_reemit
def test_catalog_rebuild_produces_consistent_target(catalog_with_bsl_query, tmpdir):
catalog, _, _ = catalog_with_bsl_query
target = _replay_rebuild(catalog, Path(tmpdir).joinpath("tgt"))
assert len(target.list()) == len(catalog.list())
assert set(target.list_aliases()) == set(catalog.list_aliases())
target.assert_consistency()


@requires_reemit
def test_catalog_rebuild_bsl_entry_exists(catalog_with_bsl_query, tmpdir):
catalog, _, _ = catalog_with_bsl_query
target = _replay_rebuild(catalog, Path(tmpdir).joinpath("tgt"))
entry = target.get_catalog_entry("origin-delays", maybe_alias=True)
assert entry is not None


@requires_reemit
def test_catalog_rebuild_bsl_entry_executes(catalog_with_bsl_query, tmpdir):
catalog, _, _ = catalog_with_bsl_query
target = _replay_rebuild(catalog, Path(tmpdir).joinpath("tgt"))
entry = target.get_catalog_entry("origin-delays", maybe_alias=True)
result = entry.lazy_expr.execute()
assert len(result) == 3
assert "origin" in result.columns
assert "avg_delay" in result.columns


# ---------------------------------------------------------------------------
# Catalog rebuild: base model (SemanticTableOp)
# ---------------------------------------------------------------------------


@pytest.fixture
def catalog_with_base_model(tmpdir):
from xorq.catalog.bind import bind

catalog, xo = _make_catalog(tmpdir)

source_entry, transform_entry = _add_source_with_identity_transform(
catalog,
xo,
{"city": ["NYC", "LA"], "pop": [8_000_000, 4_000_000]},
source_alias="cities",
transform_alias="cities-identity",
)

bound = bind(source_entry, transform_entry)
model = SemanticModel(
table=bound,
dimensions={"city": lambda t: t.city},
measures={"total_pop": lambda t: t.pop.sum()},
name="city_model",
)
tagged = to_tagged(model)
bsl_entry = catalog.add(tagged, aliases=("city-stats",))

return catalog, source_entry, bsl_entry


@requires_reemit
def test_catalog_rebuild_base_model(catalog_with_base_model, tmpdir):
catalog, _, _ = catalog_with_base_model
target = _replay_rebuild(catalog, Path(tmpdir).joinpath("tgt"))
assert set(target.list_aliases()) == set(catalog.list_aliases())
target.assert_consistency()


@requires_reemit
def test_catalog_rebuild_base_model_executes(catalog_with_base_model, tmpdir):
catalog, _, _ = catalog_with_base_model
target = _replay_rebuild(catalog, Path(tmpdir).joinpath("tgt"))
entry = target.get_catalog_entry("city-stats", maybe_alias=True)
result = entry.lazy_expr.execute()
assert len(result) == 2


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


@requires_reemit
def test_reemit_raises_on_missing_parent(simple_model):
tagged = to_tagged(simple_model)
node = _tag_node(tagged)
original_parent = node.parent
try:
node.parent = None
with pytest.raises(ValueError, match="no parent"):
reemit(node, rebuild_subexpr=lambda e: e)
finally:
node.parent = original_parent
Loading