diff --git a/src/boring_semantic_layer/serialization/tag_handler.py b/src/boring_semantic_layer/serialization/tag_handler.py index 1cc065f..91df1b9 100644 --- a/src/boring_semantic_layer/serialization/tag_handler.py +++ b/src/boring_semantic_layer/serialization/tag_handler.py @@ -29,22 +29,49 @@ def extract_metadata(tag_node) -> dict[str, Any]: """Return sidecar metadata (dimension/measure names) for a BSL-tagged node. - Walks nested metadata (the ``source`` chain) down to the innermost - ``SemanticTableOp`` and extracts the dimension/measure name tuples. + Walks nested metadata down to every ``SemanticTableOp`` leaf and unions + their dimension / measure / calc-measure names. ``source`` chains are + descended through; ``SemanticJoinOp`` nodes branch into ``left`` and + ``right``. Names from a joined leaf are prefixed with the leaf's table + name (matching how a joined ``SemanticTable`` exposes its fields, e.g. + ``flights.flight_count``); a non-joined model returns flat names. """ - table_meta: Any = tag_node.metadata - while table_meta.get("bsl_op_type") != "SemanticTableOp" and ( - src := table_meta.get("source") - ): - table_meta = dict(src) if isinstance(src, tuple) else src - dims = tuple(d[0] for d in table_meta.get("dimensions", ())) - measures = tuple(m[0] for m in table_meta.get("measures", ())) - return { + + def as_dict(meta: Any) -> dict[str, Any]: + return dict(meta) if isinstance(meta, tuple) else meta + + def collect(meta: Any, *, in_join: bool) -> tuple[list[str], list[str], list[str]]: + meta = as_dict(meta) + op_type = meta.get("bsl_op_type") + + if (src := meta.get("source")) is not None: + return collect(src, in_join=in_join) + + if op_type == "SemanticJoinOp": + ld, lm, lc = collect(meta.get("left", {}), in_join=True) + rd, rm, rc = collect(meta.get("right", {}), in_join=True) + return ld + rd, lm + rm, lc + rc + + if op_type == "SemanticTableOp": + name = meta.get("name") + prefix = f"{name}." if (in_join and name) else "" + dims = [prefix + d[0] for d in meta.get("dimensions", ())] + meas = [prefix + m[0] for m in meta.get("measures", ())] + calc = [prefix + c[0] for c in meta.get("calc_measures", ())] + return dims, meas, calc + + return [], [], [] + + dims, measures, calc = collect(tag_node.metadata, in_join=False) + result: dict[str, Any] = { "type": "semantic_model", "description": f"{len(dims)} dims, {len(measures)} measures", - "dimensions": dims, - "measures": measures, + "dimensions": tuple(dims), + "measures": tuple(measures), } + if calc: + result["calc_measures"] = tuple(calc) + return result def from_tag_node(tag_node): diff --git a/src/boring_semantic_layer/tests/test_xorq_tag_handler.py b/src/boring_semantic_layer/tests/test_xorq_tag_handler.py index 88623cf..9296da7 100644 --- a/src/boring_semantic_layer/tests/test_xorq_tag_handler.py +++ b/src/boring_semantic_layer/tests/test_xorq_tag_handler.py @@ -17,7 +17,7 @@ import ibis import pytest -from boring_semantic_layer import SemanticModel +from boring_semantic_layer import SemanticModel, to_semantic_table from boring_semantic_layer.serialization import to_tagged from boring_semantic_layer.serialization.tag_handler import ( bsl_tag_handler, @@ -114,6 +114,54 @@ def test_extract_metadata_walks_source_chain(simple_model): assert set(meta["measures"]) == {"sum_b", "avg_b"} +def test_extract_metadata_walks_join_branches(): + """For joined models the handler must descend into both ``left`` and + ``right`` branches and union dim/measure names from every leaf + ``SemanticTableOp``, prefixing them with the leaf's table name to match + how a joined ``SemanticTable`` exposes its fields.""" + t1 = ibis.memtable({"id": [1, 2], "name": ["a", "b"]}) + t2 = ibis.memtable({"id": [1, 2], "value": [10, 20]}) + t3 = ibis.memtable({"id": [1, 2], "extra": ["x", "y"]}) + + st1 = ( + to_semantic_table(t1, name="t1") + .with_dimensions(id=lambda t: t.id, name=lambda t: t.name) + .with_measures(count=lambda t: t.count()) + ) + st2 = ( + to_semantic_table(t2, name="t2") + .with_dimensions(id=lambda t: t.id) + .with_measures(total=lambda t: t.value.sum()) + ) + st3 = ( + to_semantic_table(t3, name="t3") + .with_dimensions(id=lambda t: t.id, extra=lambda t: t.extra) + .with_measures(extra_count=lambda t: t.count()) + ) + + # Two-arm join chain: covers nested SemanticJoinOp on the left as well as + # a query wrapper on top, exercising the same path as the original bug + # where every leaf was being missed. + joined = st1.join_one(st2, on=lambda l, r: l.id == r.id).join_one( + st3, on=lambda l, r: l.id == r.id + ) + query = joined.query(dimensions=("t1.name",), measures=("t1.count",)) + tag_node = _tag_node(to_tagged(query)) + + meta = extract_metadata(tag_node) + + assert set(meta["dimensions"]) == { + "t1.id", + "t1.name", + "t2.id", + "t3.id", + "t3.extra", + } + assert set(meta["measures"]) == {"t1.count", "t2.total", "t3.extra_count"} + assert "5 dims" in meta["description"] + assert "3 measures" in meta["description"] + + # --------------------------------------------------------------------------- # from_tag_node # ---------------------------------------------------------------------------