From 4fd1def0be9493c6413a5b75776a7a021bbfc38e Mon Sep 17 00:00:00 2001 From: Egor Kraev Date: Thu, 14 May 2026 15:04:19 +0200 Subject: [PATCH 1/3] DEV-1414: per-bucket ranking invariance in search() Memory/example_query/entity rankings now depend only on the corpus, question, datasource, and that bucket's own cap. Previously a shared ``over_fetch_budget = max(max_memories + max_example_queries, max_entities) * 5`` fed all three channels, so changing any one cap reshuffled the bottom cliff of channels 2 and 3 (mixed memory+entity hits) and rippled into the top of every output bucket. Fix: - search_index gains kind_filter / exclude_kind keyword-only params (mutex), built via tantivy.Query.boolean_query + term_query. - build_in_memory_corpus uses writer(num_threads=1) so tantivy doc-id tiebreak on equal BM25 scores is deterministic across rebuilds. - service.py drops _OVER_FETCH_MULTIPLIER and over_fetch_budget; channels 2 and 3 each run twice (once per kind, with limit = full per-kind corpus size), channel 1 stops truncating. Co-Authored-By: Claude Opus 4.7 (1M context) --- .claude/skills/slayer-overview.md | 2 +- CLAUDE.md | 2 +- docs/concepts/search.md | 14 + slayer/search/index.py | 43 ++- slayer/search/service.py | 209 ++++++----- tests/test_search_index.py | 122 +++++++ tests/test_search_invariance.py | 580 ++++++++++++++++++++++++++++++ 7 files changed, 886 insertions(+), 86 deletions(-) create mode 100644 tests/test_search_invariance.py diff --git a/.claude/skills/slayer-overview.md b/.claude/skills/slayer-overview.md index b72637a..bead7f5 100644 --- a/.claude/skills/slayer-overview.md +++ b/.claude/skills/slayer-overview.md @@ -32,7 +32,7 @@ Datasources: `create_datasource`, `list_datasources`, `describe_datasource` (inc Ingestion: `ingest_datasource_models` Schema drift: `validate_models` (read-only diff against live schema; surfaces `SchemaDriftError` cleanups) Memory write side: `save_memory`, `forget_memory` (per-entity learnings indexed by canonical entity strings — see [memories.md](../../docs/concepts/memories.md)) -Search: `search` (three-channel: entity-overlap BM25 over memories + tantivy full-text over memories ∪ entities + optional dense embedding similarity, RRF-fused; embeddings require the `embedding_search` extra and degrade gracefully when unavailable; partitions query-bearing memories into `example_queries` — see [search.md](../../docs/concepts/search.md)) +Search: `search` (three-channel: entity-overlap BM25 over memories + tantivy full-text + optional dense embedding similarity, RRF-fused per kind so each output bucket — `memories` / `example_queries` / `entities` — has membership/order invariant under the other buckets' caps; embeddings require the `embedding_search` extra and degrade gracefully when unavailable; partitions query-bearing memories into `example_queries` — see [search.md](../../docs/concepts/search.md)) ## Package Structure diff --git a/CLAUDE.md b/CLAUDE.md index f4d31ba..c157550 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -93,7 +93,7 @@ poetry run ruff check slayer/ tests/ - **Memories + semantic search** (DEV-1357 + DEV-1375): An agent-memory layer indexed by canonical entity strings. Two write-side tools — `save_memory(learning, linked_entities)` and `forget_memory(id)` — record per-entity notes (optionally bundled with an example `SlayerQuery`). Retrieval is unified into a single `search(entities, query, question, max_memories=5, max_example_queries=2, max_entities=5)` tool — there is no separate `recall_memories` surface. `linked_entities` accepts either a list of entity strings (resolved strictly) or an inline `SlayerQuery`/dict (entities auto-extracted; warnings non-fatal; the query is persisted on the memory). The canonical form is exactly one of ``, `.`, `..` (≤ 3 dotted segments after canonicalisation). Aggregation suffixes are stripped (`revenue:sum` → `..revenue`); `*:count` collapses to the source model; multi-hop dotted paths keep only the leaf (`orders.customers.regions.name` → `{.orders, .regions.name}`). The resolver lives in `slayer/memories/resolver.py`; the unified `Memory` row + storage primitives are concrete on `StorageBackend` (ID format / entity-intersection filter), with backends only implementing the row-shaped CRUD + a one-line `_next_memory_seq` that derives the next id from the existing corpus. `inspect_model` auto-renders a `Learnings` section listing only memories where `query is None`; query-bearing memories surface only via `search` (in the `example_queries` bucket). **Memory ids** (DEV-1405): positive ints that increase monotonically while the corpus grows; YAMLStorage derives the next id from the last row of `memories.yaml`, SQLiteStorage from `SELECT MAX(id) + 1 FROM memories`. Ids of deleted memories may be reused by future saves; `delete_memory` already cascades to the matching embedding row so reuse strands no data. - `search` runs up to three parallel channels merged by RRF (DEV-1386 adds the third). **Channel 1** is entity-overlap BM25 over memories (`slayer/memories/ranker.py` using `rank_bm25.BM25Plus`, DEV-1365) — a precisely-tagged memory outranks one with a long entity list that overlaps incidentally. **Channel 2** is a fresh in-memory tantivy index built per call over memories ∪ entities (datasources / non-hidden models / non-hidden columns / named measures / aggregations), using tantivy's `en_stem` analyzer (Porter stemmer + default tokenizer, splits on `_` and `.`). **Channel 3** (DEV-1386, optional via the `embedding_search` pip extra) is dense embedding similarity over the same memories ∪ entities corpus, computed numpy-only against rows persisted in a sidecar `embeddings` table keyed by `(canonical_id, embedding_model_name)`. The SQL lives in `slayer/storage/sidecar_embedding_store.py` (DEV-1405) — both `SQLiteStorage` and `YAMLStorage` instantiate a `SidecarEmbeddingStore` and forward all embedding CRUD to it. SQLiteStorage points it at the main `.db` file; YAMLStorage points it at a dedicated `/embeddings.db` sidecar so the YAML store keeps its git-diffable shape while embeddings live in a fast indexed store. **Cascade semantics** (DEV-1405 fix): `delete_embeddings_for_canonical(canonical_id_prefix=X)` matches the canonical id exactly OR as a strict dotted-path descendant (`X + "." + …`) — never as a character prefix, so deleting `memory:4` no longer also nukes `memory:42`. **Hot-path batching** (DEV-1405): `StorageBackend` exposes `save_embeddings(rows)` and `get_embeddings_for_canonical_ids(canonical_ids, embedding_model_name)` with default M-iteration impls, overridden by the bundled backends to issue single batched round-trips through `SidecarEmbeddingStore.save_many` / `get_many`; `EmbeddingService._apply_pending` uses them so one `refresh_model_subtree` issues exactly one batched read + one batched write regardless of subtree size. The active embedding model is read from `SLAYER_EMBEDDING_MODEL` (default `openai/text-embedding-3-small`) and dispatched via litellm; provider credentials are read by litellm directly (`OPENAI_API_KEY`, etc.). When the extra is not installed, the model has no rows, or the query embedding call fails, channel 3 contributes nothing and emits a single warning into `SearchResponse.warnings`; tantivy + BM25 continue to work. Refresh runs inline on `slayer ingest` / `edit_model` / `save_memory` and skips the litellm call when the rendered `content_hash` matches the stored row (cheap idempotent re-runs). Per-entity embed failures are non-fatal — search degrades gracefully. Memory rankings from every active channel are fused via Reciprocal Rank Fusion (`k=60`, hand-rolled in `slayer/search/rrf.py`); **entity hits from channels 2 and 3 are now also RRF-fused** (channel 1 contributes only to memory ranking). Memory hits are partitioned by `Memory.query is None` into `memories` (learning-only, small) and `example_queries` (query-bearing, bulky) — independent caps via `max_memories` and `max_example_queries` so bulky examples cannot crowd out small learnings. The response also echoes `resolved_input_entities` for diagnostics. Empty-input fallback returns the newest `max_memories` learning-only + newest `max_example_queries` query-bearing memories with a warning. Each indexed entity carries a `text` field rendered by `slayer/search/render.py` — named children (columns / measures / aggregations / join targets) are mentioned by name + kind only (no descriptions, since each child has its own indexed doc), while non-named children (model filters, model `sql` block, join `pairs`, aggregation `params`) are included in full. `meta` is **excluded** from indexed text (DEV-1377 hardening). Hidden models / hidden columns are skipped. **`datasource` filter** (DEV-1409): all four surfaces (`MCP search`, `POST /search`, `slayer search --datasource`, `SlayerClient.search`) accept an optional `datasource: Optional[str] = None`. When set, every channel pre-filters its corpus to that one datasource — entity hits only include docs rooted at it (exact name or strict dotted-path descendant); memories surface when any of their `entities` is rooted at it (memories spanning multiple datasources surface from each); BM25 / IDF / cosine corpus reflect only the filtered subset. Unknown datasource → `ValueError` (HTTP 400 on REST). Helper: `slayer.memories.resolver.canonical_id_rooted_at`. + `search` runs up to three parallel channels merged by RRF (DEV-1386 adds the third). **Channel 1** is entity-overlap BM25 over memories (`slayer/memories/ranker.py` using `rank_bm25.BM25Plus`, DEV-1365) — a precisely-tagged memory outranks one with a long entity list that overlaps incidentally. **Channel 2** is a fresh in-memory tantivy index built per call over memories ∪ entities (datasources / non-hidden models / non-hidden columns / named measures / aggregations), using tantivy's `en_stem` analyzer (Porter stemmer + default tokenizer, splits on `_` and `.`). **Channel 3** (DEV-1386, optional via the `embedding_search` pip extra) is dense embedding similarity over the same memories ∪ entities corpus, computed numpy-only against rows persisted in a sidecar `embeddings` table keyed by `(canonical_id, embedding_model_name)`. The SQL lives in `slayer/storage/sidecar_embedding_store.py` (DEV-1405) — both `SQLiteStorage` and `YAMLStorage` instantiate a `SidecarEmbeddingStore` and forward all embedding CRUD to it. SQLiteStorage points it at the main `.db` file; YAMLStorage points it at a dedicated `/embeddings.db` sidecar so the YAML store keeps its git-diffable shape while embeddings live in a fast indexed store. **Cascade semantics** (DEV-1405 fix): `delete_embeddings_for_canonical(canonical_id_prefix=X)` matches the canonical id exactly OR as a strict dotted-path descendant (`X + "." + …`) — never as a character prefix, so deleting `memory:4` no longer also nukes `memory:42`. **Hot-path batching** (DEV-1405): `StorageBackend` exposes `save_embeddings(rows)` and `get_embeddings_for_canonical_ids(canonical_ids, embedding_model_name)` with default M-iteration impls, overridden by the bundled backends to issue single batched round-trips through `SidecarEmbeddingStore.save_many` / `get_many`; `EmbeddingService._apply_pending` uses them so one `refresh_model_subtree` issues exactly one batched read + one batched write regardless of subtree size. The active embedding model is read from `SLAYER_EMBEDDING_MODEL` (default `openai/text-embedding-3-small`) and dispatched via litellm; provider credentials are read by litellm directly (`OPENAI_API_KEY`, etc.). When the extra is not installed, the model has no rows, or the query embedding call fails, channel 3 contributes nothing and emits a single warning into `SearchResponse.warnings`; tantivy + BM25 continue to work. Refresh runs inline on `slayer ingest` / `edit_model` / `save_memory` and skips the litellm call when the rendered `content_hash` matches the stored row (cheap idempotent re-runs). Per-entity embed failures are non-fatal — search degrades gracefully. Memory rankings from every active channel are fused via Reciprocal Rank Fusion (`k=60`, hand-rolled in `slayer/search/rrf.py`); **entity hits from channels 2 and 3 are now also RRF-fused** (channel 1 contributes only to memory ranking). Memory hits are partitioned by `Memory.query is None` into `memories` (learning-only, small) and `example_queries` (query-bearing, bulky) — independent caps via `max_memories` and `max_example_queries` so bulky examples cannot crowd out small learnings. The response also echoes `resolved_input_entities` for diagnostics. Empty-input fallback returns the newest `max_memories` learning-only + newest `max_example_queries` query-bearing memories with a warning. Each indexed entity carries a `text` field rendered by `slayer/search/render.py` — named children (columns / measures / aggregations / join targets) are mentioned by name + kind only (no descriptions, since each child has its own indexed doc), while non-named children (model filters, model `sql` block, join `pairs`, aggregation `params`) are included in full. `meta` is **excluded** from indexed text (DEV-1377 hardening). Hidden models / hidden columns are skipped. **`datasource` filter** (DEV-1409): all four surfaces (`MCP search`, `POST /search`, `slayer search --datasource`, `SlayerClient.search`) accept an optional `datasource: Optional[str] = None`. When set, every channel pre-filters its corpus to that one datasource — entity hits only include docs rooted at it (exact name or strict dotted-path descendant); memories surface when any of their `entities` is rooted at it (memories spanning multiple datasources surface from each); BM25 / IDF / cosine corpus reflect only the filtered subset. Unknown datasource → `ValueError` (HTTP 400 on REST). Helper: `slayer.memories.resolver.canonical_id_rooted_at`. **Per-bucket ranking invariance** (DEV-1414): channel 2 runs as two kind-filtered tantivy queries (one over memory docs, one over entity docs); channel 3 partitions the embedding corpus by `entity_kind` and ranks each side independently. There is no shared candidate-pool budget across kinds, so for a fixed `(question, datasource, max_X)` the membership and order of each output bucket (`memories` / `example_queries` / `entities`) is a pure function of the corpus + question + that one cap — varying the other two caps cannot move ids in or out of the returned list nor reorder it. The kind-filtered tantivy queries are emitted as boolean queries via `tantivy.Query.boolean_query` + `tantivy.Query.term_query` (`search_index`'s new `kind_filter` / `exclude_kind` params). The in-memory tantivy index is built with `writer(num_threads=1)` so doc-id tiebreak on equal BM25 scores is deterministic across rebuilds. Sample-value snapshots cached on `Column.sampled` (v6 schema bump, no-op forward migration in `slayer/storage/v6_migration.py`); refreshed on every `slayer ingest` for table-backed models, on `slayer search refresh-samples`, on `edit_model` (column-level edits → that column; `model.filters` / `model.sql` / `source_queries` change → all columns), and lazily on `inspect_model` cache miss (best-effort write-back). sql-mode and query-backed sample-value coverage is deferred to [DEV-1377](https://linear.app/motley-ai/issue/DEV-1377). Surfaces: write side via MCP, REST (`POST /memories`, `DELETE /memories/{id}`), CLI (`slayer memory {save,forget}`), and `SlayerClient`; retrieval via MCP (`search`), REST (`POST /search`), CLI (`slayer search [--entity ...] [--query ...] [--question ...] [--max-example-queries N]`, `slayer search refresh-samples`), and `SlayerClient.search()`. See [docs/concepts/memories.md](docs/concepts/memories.md) and [docs/concepts/search.md](docs/concepts/search.md). diff --git a/docs/concepts/search.md b/docs/concepts/search.md index 2578999..3374bf4 100644 --- a/docs/concepts/search.md +++ b/docs/concepts/search.md @@ -112,6 +112,20 @@ Entity rankings from channels 2 and 3 are RRF-fused the same way. Channel 1 contributes to the memory ranking only (it operates on memory entity tags, not on entity docs). +### Per-bucket ranking invariance (DEV-1414) + +Each channel produces a **full per-kind ranking** — channel 2 runs as +two kind-filtered tantivy queries (one over memory docs only, one over +entity docs only), and channel 3 partitions the embedding corpus by +`entity_kind` and ranks each side independently. There is no shared +candidate-pool budget across kinds, so for a fixed +`(question, datasource, max_X)` the membership and order of the +returned `X` bucket (`memories` / `example_queries` / `entities`) is a +pure function of the corpus + question + that one cap. Varying the +other two caps cannot move an id in or out of the returned list nor +reorder it. The `max_*` caps are pure post-fusion slice operations on +the three independent ranked lists. + ## Tool surface ```python diff --git a/slayer/search/index.py b/slayer/search/index.py index f1302ae..88cbf57 100644 --- a/slayer/search/index.py +++ b/slayer/search/index.py @@ -192,7 +192,15 @@ def build_in_memory_corpus( """ schema = _build_schema() index = tantivy.Index(schema=schema) - writer = index.writer() + # `num_threads=1` pins doc-id assignment to insertion order so the + # tantivy tiebreak (lower internal doc id wins on equal scores) is + # deterministic across rebuilds (DEV-1414). The default + # ``num_threads=0`` lets tantivy auto-pick a thread count, and with + # multiple writer threads the order in which threads commit their + # local segments determines doc-id assignment — which is + # non-deterministic for small in-RAM corpora that finish + # processing within microseconds. + writer = index.writer(num_threads=1) visible_models = [m for m in models if not m.hidden] pairs = _collect_render_pairs( @@ -240,6 +248,8 @@ def search_index( question: str, limit: int = 20, fields: Optional[List[str]] = None, + kind_filter: Optional[str] = None, + exclude_kind: Optional[str] = None, ) -> List[IndexHit]: """Run a tantivy query against ``index``. @@ -250,10 +260,23 @@ def search_index( limit: Max hits to return. fields: Which schema fields to query against (default: ``["text"]``). Pass ``["canonical"]`` for an exact-match canonical lookup. + kind_filter: When set, restrict results to docs whose ``kind`` + field exactly equals this value (e.g. ``"memory"``, + ``"model"``). Combined with the text query via ``Must``. + exclude_kind: When set, exclude docs whose ``kind`` field equals + this value. Combined with the text query via ``MustNot``. + ``kind_filter`` and ``exclude_kind`` are mutually exclusive + (DEV-1414): one is for keeping a single kind, the other for + dropping a single kind. Pass at most one. Returns: List of :class:`IndexHit` in score-desc order. """ + if kind_filter is not None and exclude_kind is not None: + raise ValueError( + "kind_filter and exclude_kind are mutually exclusive; pass " + "at most one." + ) if not question or not question.strip(): return [] if fields is None: @@ -262,6 +285,24 @@ def search_index( query = index.parse_query(question, fields) except (ValueError, RuntimeError): return [] + if kind_filter is not None or exclude_kind is not None: + schema = index.schema + if kind_filter is not None: + kind_term = tantivy.Query.term_query( + schema, "kind", kind_filter, + ) + query = tantivy.Query.boolean_query([ + (tantivy.Occur.Must, query), + (tantivy.Occur.Must, kind_term), + ]) + else: + kind_term = tantivy.Query.term_query( + schema, "kind", exclude_kind, + ) + query = tantivy.Query.boolean_query([ + (tantivy.Occur.Must, query), + (tantivy.Occur.MustNot, kind_term), + ]) searcher = index.searcher() raw_hits = searcher.search(query, limit).hits out: List[IndexHit] = [] diff --git a/slayer/search/service.py b/slayer/search/service.py index bf9a07a..5d3dd67 100644 --- a/slayer/search/service.py +++ b/slayer/search/service.py @@ -4,21 +4,32 @@ (``slayer.memories.ranker.bm25_rank``). Skipped when neither ``entities`` nor ``query`` is supplied. Contributes only to the memory ranking. -* **Channel 2** — tantivy full-text over memories ∪ entities. Skipped - when ``question`` is empty. Contributes to both the memory ranking - and the entity ranking. -* **Channel 3** — dense embedding similarity over memories ∪ entities - (DEV-1386). Skipped when ``question`` is empty, when the - ``embedding_search`` extra is not installed, when the query - embedding call fails, or when there are no embedding rows for the - active model name. Contributes to both the memory ranking and the - entity ranking. +* **Channel 2** — tantivy full-text. Skipped when ``question`` is + empty. Runs as TWO kind-filtered queries per call (DEV-1414): one + with ``kind_filter="memory"`` for the memory ranking, one with + ``exclude_kind="memory"`` for the entity ranking. Each query ranks + the full per-kind subset of the corpus — no over-fetch truncation. +* **Channel 3** — dense embedding similarity (DEV-1386). Skipped when + ``question`` is empty, when the ``embedding_search`` extra is not + installed, when the query embedding call fails, or when there are + no embedding rows for the active model name. The persisted embedding + rows are partitioned by ``entity_kind`` (DEV-1414): memory rows feed + the memory ranking, non-memory rows feed the entity ranking. Each + partition is ranked in full. Memory rankings from every active channel are fused via RRF (``k = 60``). Entity rankings from channels 2 and 3 are fused the same way. Channel 1 does not contribute to entity ranking (it operates on memory entity tags, not on entity docs). +Per-bucket invariance (DEV-1414): because each channel produces a full +per-kind ranking — never truncated by a shared candidate-pool budget — +the membership and order of every output bucket (``memories``, +``example_queries``, ``entities``) is a pure function of the corpus, +the question, the datasource filter, and that bucket's own cap. Varying +the other two caps cannot move ids in or out of the returned list nor +reorder it. + Empty input (no entities, no query, no question) falls back to recency: newest ``max_memories`` learning-only memories + newest ``max_example_queries`` query-bearing memories, with a warning. @@ -52,7 +63,6 @@ _RRF_K = 60 -_OVER_FETCH_MULTIPLIER = 5 # --------------------------------------------------------------------------- @@ -133,23 +143,6 @@ def _dedup(items: List[str]) -> List[str]: return out -def _split_tantivy_hits( - hits: List[IndexHit], -) -> Tuple[List[int], List[IndexHit], dict[int, IndexHit]]: - """Sort tantivy hits into the memory ranking, the entity-hit list, and - a memory-id→hit lookup.""" - memory_ranking: List[int] = [] - entity_hits: List[IndexHit] = [] - by_memory_id: dict[int, IndexHit] = {} - for hit in hits: - if hit.kind == "memory" and hit.memory_id is not None: - memory_ranking.append(hit.memory_id) - by_memory_id[hit.memory_id] = hit - else: - entity_hits.append(hit) - return memory_ranking, entity_hits, by_memory_id - - def _backfill_memory_by_id( *, memory_by_id: dict, @@ -278,27 +271,29 @@ def _filter_embedding_corpus_by_datasource( ] -def _split_embedding_pairs( - pairs: List[Tuple[int, float]], - rows: List["Embedding"], -) -> Tuple[List[int], List[str]]: - """Split ``(row_idx, score)`` pairs from cosine top-k into - ``(memory_ranking, entity_ranking)``. Skips memory rows whose - ``canonical_id`` doesn't parse back to a valid int — those would - only appear if a backend smuggled in a malformed row.""" - memory_ranking: List[int] = [] - entity_ranking: List[str] = [] - for idx, _score in pairs: - row = rows[idx] - if row.entity_kind == "memory": - try: - memory_id = int(row.canonical_id.split(":", 1)[1]) - except (IndexError, ValueError): - continue - memory_ranking.append(memory_id) +def _count_corpus_kinds(corpus: Corpus) -> Tuple[int, int]: + """Return ``(memory_count, entity_count)`` for a built corpus. Used + by channel 2 to pass ``limit = full per-kind corpus size`` to each + kind-filtered tantivy query so neither kind's ranking is truncated + (DEV-1414).""" + memory_count = 0 + entity_count = 0 + for kind in corpus.canonical_to_kind.values(): + if kind == "memory": + memory_count += 1 else: - entity_ranking.append(row.canonical_id) - return memory_ranking, entity_ranking + entity_count += 1 + return memory_count, entity_count + + +def _memory_id_from_canonical(canonical_id: str) -> Optional[int]: + """Parse a memory row's canonical id back into the int memory id. + Returns ``None`` if the format is malformed — only possible if a + backend smuggled in a non-conforming row.""" + try: + return int(canonical_id.split(":", 1)[1]) + except (IndexError, ValueError): + return None def _fuse_entity_hits( @@ -384,12 +379,6 @@ async def search( warnings=warnings, ) - # Over-fetch must cover the worst case where every fused hit lands - # in the same bucket (all learnings, or all example queries). - over_fetch_budget = max( - max_memories + max_example_queries, max_entities, - ) * _OVER_FETCH_MULTIPLIER - # Single memory-corpus fetch shared by all channels. Pre-filtered # by ``datasource`` so BM25 (channel 1) and the embedding cosine # (channel 3) consume the narrowed list — IDF / matrix shape @@ -418,7 +407,6 @@ async def search( channel_1_memory_ranking, memory_by_id = self._run_channel_1( canonical_input_entities=canonical_input_entities, all_memories=all_memories, - over_fetch=over_fetch_budget, channel_1_active=channel_1_active, ) ( @@ -428,7 +416,6 @@ async def search( ) = self._run_channel_2( corpus=corpus, question=question, - over_fetch=over_fetch_budget, ) ( channel_3_memory_ranking, @@ -437,7 +424,6 @@ async def search( ) = await self._run_channel_3( question=question, corpus=corpus, - over_fetch=over_fetch_budget, question_active=question_active, datasource=datasource, eligible_memory_canonicals={f"memory:{m.id}" for m in all_memories}, @@ -576,10 +562,10 @@ def _run_channel_1( *, canonical_input_entities: List[str], all_memories: List[Memory], - over_fetch: int, channel_1_active: bool, ) -> Tuple[List[int], dict[int, Memory]]: - """Entity-overlap BM25 channel.""" + """Entity-overlap BM25 channel. Ranks the full memory corpus — + no candidate-pool truncation (DEV-1414).""" channel_1_memory_ranking: List[int] = [] memory_by_id: dict[int, Memory] = {} if channel_1_active and canonical_input_entities: @@ -587,7 +573,7 @@ def _run_channel_1( memories=all_memories, query_entities=canonical_input_entities, ) - for memory, _score in ranked[:over_fetch]: + for memory, _score in ranked: memory_by_id[memory.id] = memory channel_1_memory_ranking.append(memory.id) return channel_1_memory_ranking, memory_by_id @@ -597,21 +583,50 @@ def _run_channel_2( *, corpus: Optional[Corpus], question: Optional[str], - over_fetch: int, ) -> Tuple[List[int], List[str], dict[int, IndexHit]]: - """Tantivy full-text channel. Returns - ``(memory_ranking, entity_ranking_canonicals, by_memory_id_hits)``. - Empty when ``corpus`` or ``question`` is missing.""" + """Tantivy full-text channel. + + DEV-1414: runs as TWO kind-filtered queries — one over memory + docs only, one over entity docs only — so the per-kind ranking + is a pure function of the corpus + question, never affected by + the other kind's cap. The ``limit`` for each call is the size of + the corresponding kind in the corpus, so each query returns the + complete per-kind ranking. + + Returns ``(memory_ranking, entity_ranking_canonicals, + by_memory_id_hits)``. Empty when ``corpus`` or ``question`` is + missing. + """ if corpus is None or not question or not question.strip(): return [], [], {} - tantivy_hits = search_index( - index=corpus.index, question=question, limit=over_fetch, + memory_count, entity_count = _count_corpus_kinds(corpus) + memory_hits = ( + search_index( + index=corpus.index, + question=question, + limit=memory_count, + kind_filter="memory", + ) + if memory_count > 0 + else [] ) - ( - memory_ranking, - entity_hits, - by_memory_id, - ) = _split_tantivy_hits(tantivy_hits) + entity_hits = ( + search_index( + index=corpus.index, + question=question, + limit=entity_count, + exclude_kind="memory", + ) + if entity_count > 0 + else [] + ) + memory_ranking: List[int] = [] + by_memory_id: dict[int, IndexHit] = {} + for hit in memory_hits: + if hit.memory_id is None: + continue + memory_ranking.append(hit.memory_id) + by_memory_id[hit.memory_id] = hit entity_ranking = [h.id for h in entity_hits] return memory_ranking, entity_ranking, by_memory_id @@ -620,7 +635,6 @@ async def _run_channel_3( *, question: Optional[str], corpus: Optional[Corpus], - over_fetch: int, question_active: bool, datasource: Optional[str] = None, eligible_memory_canonicals: Optional[Set[str]] = None, @@ -628,6 +642,10 @@ async def _run_channel_3( """Embedding-similarity channel (DEV-1386). Returns ``(memory_ranking, entity_ranking_canonicals, warnings)``. + DEV-1414: the corpus is partitioned by ``entity_kind`` and each + kind is ranked in full via two cosine calls. The per-kind + ranking is a pure function of the corpus + question. + Skipped (with a warning) when: * ``question`` is empty, @@ -644,9 +662,6 @@ async def _run_channel_3( * memory rows whose ``canonical_id`` appears in the supplied ``eligible_memory_canonicals`` set (already datasource-filtered upstream). - - Both `rows[idx]` indexing and the matrix stay aligned because - the filter happens before `np.array(...)`. """ if not question_active or corpus is None: return [], [], [] @@ -692,20 +707,48 @@ async def _run_channel_3( return [], [], [ "embedding channel skipped: query embedding failed.", ] - matrix = np.array([r.embedding for r in rows], dtype=np.float32) - if matrix.shape[1] != len(query_vec): + # All persisted rows share the active model's dim; sample any + # row to detect a stale-dim corpus before partitioning. + if len(rows[0].embedding) != len(query_vec): return [], [], [ f"embedding channel skipped: dim mismatch " - f"(query={len(query_vec)}, corpus={matrix.shape[1]}). " + f"(query={len(query_vec)}, corpus={len(rows[0].embedding)}). " f"Re-run `slayer ingest` to refresh embeddings against " f"the current model.", ] - pairs = top_k_cosine( - query=normalise(query_vec), - matrix=normalise_matrix(matrix), - k=over_fetch, - ) - memory_ranking, entity_ranking = _split_embedding_pairs(pairs, rows) + + memory_rows = [r for r in rows if r.entity_kind == "memory"] + entity_rows = [r for r in rows if r.entity_kind != "memory"] + normalised_query = normalise(query_vec) + + memory_ranking: List[int] = [] + if memory_rows: + memory_matrix = np.array( + [r.embedding for r in memory_rows], dtype=np.float32, + ) + for idx, _score in top_k_cosine( + query=normalised_query, + matrix=normalise_matrix(memory_matrix), + k=len(memory_rows), + ): + memory_id = _memory_id_from_canonical( + memory_rows[idx].canonical_id, + ) + if memory_id is not None: + memory_ranking.append(memory_id) + + entity_ranking: List[str] = [] + if entity_rows: + entity_matrix = np.array( + [r.embedding for r in entity_rows], dtype=np.float32, + ) + for idx, _score in top_k_cosine( + query=normalised_query, + matrix=normalise_matrix(entity_matrix), + k=len(entity_rows), + ): + entity_ranking.append(entity_rows[idx].canonical_id) + return memory_ranking, entity_ranking, [] async def _collect_index_corpus( diff --git a/tests/test_search_index.py b/tests/test_search_index.py index 41262dd..ba87926 100644 --- a/tests/test_search_index.py +++ b/tests/test_search_index.py @@ -11,12 +11,16 @@ * The `canonical` field supports exact-match lookup of canonical entity strings. * The `kind` field supports filtering memory vs entity hits at query time. +* `search_index` accepts a `kind_filter` / `exclude_kind` parameter for + per-kind queries (DEV-1414). """ from __future__ import annotations from typing import List +import pytest + from slayer.core.enums import DataType from slayer.core.models import Column, SlayerModel from slayer.memories.models import Memory @@ -180,3 +184,121 @@ def test_memory_id_round_trips_as_integer() -> None: hits = search_index(index=idx, question="anonymous checkouts", limit=10) memory_hits = [h for h in hits if h.kind == "memory"] assert any(h.memory_id == 2 for h in memory_hits) + + +# --------------------------------------------------------------------------- +# DEV-1414: per-kind filtering for invariant per-bucket ranking +# --------------------------------------------------------------------------- + + +def test_kind_filter_memory_returns_only_memory_hits() -> None: + """`kind_filter="memory"` must restrict results to memory-kind docs + only, even when other-kind docs would otherwise outscore them.""" + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + hits = search_index( + index=idx, question="customer", limit=50, kind_filter="memory", + ) + assert hits, "expected at least one memory hit" + assert all(h.kind == "memory" for h in hits) + + +def test_kind_filter_model_returns_only_model_hits() -> None: + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + hits = search_index( + index=idx, question="customer", limit=50, kind_filter="model", + ) + assert hits, "expected at least one model hit" + assert all(h.kind == "model" for h in hits) + + +def test_exclude_kind_memory_returns_only_non_memory_hits() -> None: + """`exclude_kind="memory"` must return entity (datasource/model/column/ + measure/aggregation) hits and no memory hits.""" + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + hits = search_index( + index=idx, question="customer", limit=50, exclude_kind="memory", + ) + assert hits, "expected at least one entity hit" + assert all(h.kind != "memory" for h in hits) + + +def test_kind_filter_and_exclude_kind_are_mutually_exclusive() -> None: + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + with pytest.raises(ValueError): + search_index( + index=idx, + question="customer", + kind_filter="memory", + exclude_kind="memory", + ) + + +def test_kind_filtered_query_preserves_relative_score_order() -> None: + """The filter is a pure subset operation, not a re-scoring. For + every pair of memory docs (a, b), their relative order under + `kind_filter="memory"` must match their relative order in an + unfiltered query against the same `question`.""" + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + unfiltered = search_index(index=idx, question="customer", limit=50) + memory_only = search_index( + index=idx, question="customer", limit=50, kind_filter="memory", + ) + unfiltered_memory_order = [ + h.id for h in unfiltered if h.kind == "memory" + ] + memory_only_order = [h.id for h in memory_only] + assert memory_only_order == unfiltered_memory_order + + +def test_exclude_kind_query_preserves_relative_score_order() -> None: + """Symmetric: `exclude_kind="memory"` preserves the relative order of + non-memory hits as they appear in the unfiltered query.""" + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + unfiltered = search_index(index=idx, question="customer", limit=50) + non_memory = search_index( + index=idx, question="customer", limit=50, exclude_kind="memory", + ) + unfiltered_entity_order = [ + h.id for h in unfiltered if h.kind != "memory" + ] + non_memory_order = [h.id for h in non_memory] + assert non_memory_order == unfiltered_entity_order + + +def test_kind_filter_limit_clamps_to_kind_subset() -> None: + """`limit` continues to cap the number of returned hits; with + `kind_filter` set, the cap applies to the kind-restricted result list.""" + idx = build_in_memory_index( + memories=_make_memories(), + models=_make_models(), + datasources=["warehouse"], + ) + hits = search_index( + index=idx, question="customer", limit=1, kind_filter="memory", + ) + assert len(hits) <= 1 + assert all(h.kind == "memory" for h in hits) diff --git a/tests/test_search_invariance.py b/tests/test_search_invariance.py new file mode 100644 index 0000000..62c8b06 --- /dev/null +++ b/tests/test_search_invariance.py @@ -0,0 +1,580 @@ +"""Per-bucket ranking invariance (DEV-1414). + +For a fixed ``(question, datasource, max_X)``, the user-visible list of +``X`` (``memories`` / ``example_queries`` / ``entities``) must be a pure +function of the corpus + question + that one cap. Changing the OTHER +two caps must not move any id in or out of the returned ``X`` list, +nor reorder it. + +These tests exercise the bug reported in DEV-1414: the previous +``over_fetch_budget = max(max_memories + max_example_queries, +max_entities) * 5`` shared one candidate-pool cap across all three +channels, so changing ``max_entities`` or ``max_example_queries`` would +push memories in or out of the bottom of each channel's per-kind +ranking — and the membership/order at the top of the fused memory list +would shift even though the question and ``max_memories`` were fixed. +""" + +from __future__ import annotations + +import tempfile +from typing import AsyncIterator, List, Optional + +import pytest +import pytest_asyncio + +from slayer.core.enums import DataType +from slayer.core.models import ( + Column, + DatasourceConfig, + ModelMeasure, + SlayerModel, +) +from slayer.core.query import SlayerQuery +from slayer.embeddings import client as embedding_client +from slayer.search.service import SearchService +from slayer.storage.base import StorageBackend +from slayer.storage.yaml_storage import YAMLStorage + + +# --------------------------------------------------------------------------- +# Corpus fixture +# --------------------------------------------------------------------------- + + +_LEARNING_TOPICS = [ + "amount_paid is gross of refunds", + "filter status='paid' for net revenue", + "customer email may be NULL for anonymous checkouts", + "shipping rates apply only to physical goods", + "tax is computed at checkout, not at order placement", + "refund window is 30 days from order placement", + "loyalty points accrue on net revenue not gross", + "warehouse code 'EU1' is the default Europe warehouse", + "order status 'cancelled' excludes from revenue rollups", + "amount_paid in cents, divide by 100 for dollars", + "customer_id is FK to customers.id", + "checkout sessions older than 24h are abandoned", + "premium customers have customer_tier='gold'", + "free shipping over $50 net of tax", + "anonymous checkouts have NULL customer_id", + "discount_code applies before tax computation", + "order id is monotonic and never reused", + "subscription orders have recurring=true", + "fraud_check is required for orders over $1000", + "currency is always USD for the warehouse dataset", + "customer tier upgrades trigger on $5000 lifetime spend", + "refunded orders retain their original amount_paid", + "shipping warehouse selection is FIFO by region", + "email bounces flip the customer to inactive", + "discount stacking is capped at 30 percent", + "gold tier customers skip the fraud queue", + "warehouse closures move orders to backup region", + "abandoned checkouts older than 7 days are purged", + "customer email change requires re-verification", + "tax exemption applies to gold tier government accounts", + "amount_paid excludes shipping and tax", + "anonymous orders cannot have loyalty points", + "duplicate customer rows are merged on email match", + "warehouse capacity is in physical units not value", + "order amount totals always agree with payment ledger", + "customer_tier is set on first paid order", + "EU2 warehouse opened in Q2 2024", + "refund processing time is 5-7 business days", + "free shipping promo requires registered customer", + "discount_code expiry is checked at checkout", +] + + +def _make_models() -> List[SlayerModel]: + return [ + SlayerModel( + name="orders", + sql_table="public.orders", + data_source="warehouse", + description=( + "Checkout orders fact table including shipping, refund, " + "and tax detail." + ), + columns=[ + Column(name="id", type=DataType.INT, primary_key=True), + Column( + name="customer_id", type=DataType.INT, + description="FK to customers.id, NULL for anonymous.", + ), + Column( + name="amount_paid", type=DataType.DOUBLE, + description="Net paid in cents.", + ), + Column( + name="status", type=DataType.TEXT, + description="paid|refunded|cancelled|abandoned.", + ), + Column( + name="shipped_at", type=DataType.TIMESTAMP, + description="When the order shipped from warehouse.", + ), + Column( + name="discount_code", type=DataType.TEXT, + description="Optional promotional discount code.", + ), + ], + ), + SlayerModel( + name="customers", + sql_table="public.customers", + data_source="warehouse", + description="Customer master data.", + columns=[ + Column(name="id", type=DataType.INT, primary_key=True), + Column( + name="email", type=DataType.TEXT, + description="Customer email; NULL for anonymous.", + ), + Column( + name="customer_tier", type=DataType.TEXT, + description="Tier: gold|silver|standard.", + ), + ], + ), + SlayerModel( + name="warehouses", + sql_table="public.warehouses", + data_source="warehouse", + description="Physical warehouses for fulfilment.", + columns=[ + Column(name="code", type=DataType.TEXT, primary_key=True), + Column(name="region", type=DataType.TEXT), + ], + ), + ] + + +async def _seed_invariance_corpus(storage: StorageBackend) -> None: + """Seed a corpus large enough to exercise the bottom-cliff cases that + used to leak through the shared over_fetch budget.""" + await storage.save_datasource(DatasourceConfig( + name="warehouse", type="sqlite", database=":memory:", + )) + for model in _make_models(): + await storage.save_model(model) + + # 20 learning-only memories. + for i, topic in enumerate(_LEARNING_TOPICS): + # Spread entity tags so different memories surface for different + # questions. + entities: List[str] + if "amount_paid" in topic or "paid" in topic or "revenue" in topic: + entities = ["warehouse.orders.amount_paid"] + elif "email" in topic or "anonymous" in topic: + entities = ["warehouse.customers.email"] + elif "ship" in topic or "warehouse" in topic: + entities = ["warehouse.warehouses"] + elif "customer" in topic and "tier" in topic: + entities = ["warehouse.customers.customer_tier"] + elif "customer" in topic: + entities = ["warehouse.customers"] + elif "status" in topic: + entities = ["warehouse.orders.status"] + elif "discount" in topic: + entities = ["warehouse.orders.discount_code"] + elif "checkout" in topic or "fraud" in topic: + entities = ["warehouse.orders"] + else: + entities = ["warehouse"] + await storage.save_memory( + learning=f"KB{i:02d}: {topic}.", + entities=entities, + ) + + # 8 query-bearing memories — drive the example_queries bucket. + for i in range(8): + await storage.save_memory( + learning=f"Example query {i}: revenue rollup pattern.", + entities=["warehouse.orders.amount_paid"], + query=SlayerQuery( + source_model="orders", + measures=[ModelMeasure(formula="amount_paid:sum")], + ), + ) + + +@pytest_asyncio.fixture +async def storage_with_invariance_corpus() -> AsyncIterator[YAMLStorage]: + with tempfile.TemporaryDirectory() as tmp: + storage = YAMLStorage(base_dir=tmp) + await _seed_invariance_corpus(storage) + yield storage + + +@pytest_asyncio.fixture +async def service_invariance( + storage_with_invariance_corpus: YAMLStorage, +) -> SearchService: + return SearchService(storage=storage_with_invariance_corpus) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _ids(service: SearchService, **kwargs) -> dict[str, list]: + response = await service.search(**kwargs) + return { + "memories": [h.id for h in response.memories], + "example_queries": [h.id for h in response.example_queries], + "entities": [h.id for h in response.entities], + } + + +# --------------------------------------------------------------------------- +# Memory-bucket invariance under entity / example-query caps +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_memories_invariant_under_max_entities( + service_invariance: SearchService, +) -> None: + """Varying ``max_entities`` (with question + datasource + + max_memories + max_example_queries fixed) must not change the + `memories` id list or its order. Tight caps exercise the bottom + cliff in the legacy ``over_fetch_budget``.""" + base = await _ids( + service_invariance, + question="amount paid refund revenue customer email warehouse", + datasource="warehouse", + max_memories=3, + max_example_queries=0, + max_entities=2, + ) + for max_entities in (0, 1, 5, 50, 200): + other = await _ids( + service_invariance, + question="amount paid refund revenue customer email warehouse", + datasource="warehouse", + max_memories=3, + max_example_queries=0, + max_entities=max_entities, + ) + assert other["memories"] == base["memories"], ( + f"memories order changed when max_entities went 2 -> " + f"{max_entities}: {base['memories']} vs {other['memories']}" + ) + + +@pytest.mark.asyncio +async def test_memories_invariant_under_max_example_queries( + service_invariance: SearchService, +) -> None: + base = await _ids( + service_invariance, + question="amount paid refund revenue customer email warehouse", + datasource="warehouse", + max_memories=3, + max_example_queries=0, + max_entities=2, + ) + for max_example_queries in (0, 1, 5, 20, 100): + other = await _ids( + service_invariance, + question="amount paid refund revenue customer email warehouse", + datasource="warehouse", + max_memories=3, + max_example_queries=max_example_queries, + max_entities=2, + ) + assert other["memories"] == base["memories"], ( + f"memories order changed when max_example_queries went 0 -> " + f"{max_example_queries}: {base['memories']} vs " + f"{other['memories']}" + ) + + +# --------------------------------------------------------------------------- +# example_queries-bucket invariance under memory / entity caps +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_example_queries_invariant_under_max_memories( + service_invariance: SearchService, +) -> None: + base = await _ids( + service_invariance, + question="revenue rollup amount paid", + datasource="warehouse", + max_memories=5, + max_example_queries=5, + max_entities=5, + ) + for max_memories in (0, 1, 10, 50): + other = await _ids( + service_invariance, + question="revenue rollup amount paid", + datasource="warehouse", + max_memories=max_memories, + max_example_queries=5, + max_entities=5, + ) + assert other["example_queries"] == base["example_queries"], ( + f"example_queries order changed when max_memories went 5 -> " + f"{max_memories}: {base['example_queries']} vs " + f"{other['example_queries']}" + ) + + +@pytest.mark.asyncio +async def test_example_queries_invariant_under_max_entities( + service_invariance: SearchService, +) -> None: + base = await _ids( + service_invariance, + question="revenue rollup amount paid", + datasource="warehouse", + max_memories=5, + max_example_queries=5, + max_entities=5, + ) + for max_entities in (0, 1, 20, 100): + other = await _ids( + service_invariance, + question="revenue rollup amount paid", + datasource="warehouse", + max_memories=5, + max_example_queries=5, + max_entities=max_entities, + ) + assert other["example_queries"] == base["example_queries"], ( + f"example_queries order changed when max_entities went 5 -> " + f"{max_entities}: {base['example_queries']} vs " + f"{other['example_queries']}" + ) + + +# --------------------------------------------------------------------------- +# entities-bucket invariance under memory / example-query caps +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_entities_invariant_under_max_memories( + service_invariance: SearchService, +) -> None: + base = await _ids( + service_invariance, + question="amount paid refund customer email warehouse shipping", + datasource="warehouse", + max_memories=2, + max_example_queries=0, + max_entities=3, + ) + for max_memories in (0, 1, 20, 100): + other = await _ids( + service_invariance, + question="amount paid refund customer email warehouse shipping", + datasource="warehouse", + max_memories=max_memories, + max_example_queries=0, + max_entities=3, + ) + assert other["entities"] == base["entities"], ( + f"entities order changed when max_memories went 2 -> " + f"{max_memories}: {base['entities']} vs {other['entities']}" + ) + + +@pytest.mark.asyncio +async def test_entities_invariant_under_max_example_queries( + service_invariance: SearchService, +) -> None: + base = await _ids( + service_invariance, + question="amount paid refund customer email warehouse shipping", + datasource="warehouse", + max_memories=2, + max_example_queries=0, + max_entities=3, + ) + for max_example_queries in (0, 1, 5, 30): + other = await _ids( + service_invariance, + question="amount paid refund customer email warehouse shipping", + datasource="warehouse", + max_memories=2, + max_example_queries=max_example_queries, + max_entities=3, + ) + assert other["entities"] == base["entities"], ( + f"entities order changed when max_example_queries went 0 -> " + f"{max_example_queries}: {base['entities']} vs " + f"{other['entities']}" + ) + + +# --------------------------------------------------------------------------- +# DEV-1414 repro tuples +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_dev_1414_repro_tuples_yield_same_top_memories( + service_invariance: SearchService, +) -> None: + """The exact three call shapes from DEV-1414 (max_memories fixed at + the smaller of the two values, varying entity / example-query caps) + must yield identical top-``min(max_memories)`` memory ids. + + Original repro held max_memories=10 across A and B, then bumped to + 15 in C. Compare the prefix of length 10 across all three.""" + call_a = await _ids( + service_invariance, + question="amount paid refund revenue customer email", + datasource="warehouse", + max_memories=10, + max_entities=10, + max_example_queries=5, + ) + call_b = await _ids( + service_invariance, + question="amount paid refund revenue customer email", + datasource="warehouse", + max_memories=10, + max_entities=0, + max_example_queries=0, + ) + call_c = await _ids( + service_invariance, + question="amount paid refund revenue customer email", + datasource="warehouse", + max_memories=15, + max_entities=5, + max_example_queries=2, + ) + # A and B share max_memories=10 → full equality. + assert call_a["memories"] == call_b["memories"] + # C asks for 15 memories; the first 10 must match A and B. + assert call_c["memories"][:10] == call_a["memories"] + + +# --------------------------------------------------------------------------- +# Channel-3 active path (embedding) — invariance must hold there too +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def storage_with_embeddings( + monkeypatch: pytest.MonkeyPatch, +) -> AsyncIterator[YAMLStorage]: + """Same corpus as the base fixture, plus a deterministic embedding + backend stubbed in so channel 3 actually fires.""" + with tempfile.TemporaryDirectory() as tmp: + storage = YAMLStorage(base_dir=tmp) + await _seed_invariance_corpus(storage) + + embedding_client._reset_query_cache() + monkeypatch.setattr(embedding_client, "is_available", lambda: True) + + # Deterministic embeddings: hash the rendered text into a tiny + # vector so ranks vary across docs but are reproducible. + def _vec(text: str) -> List[float]: + return [ + (hash((text, i)) & 0xFFFF) / 65535.0 for i in range(8) + ] + + async def stub_embed_batch( # NOSONAR(S7503) — stub matches embed_batch async signature + texts: List[str], *, model: Optional[str] = None, + ) -> List[Optional[List[float]]]: + return [_vec(t) for t in texts] + + async def stub_embed_query( # NOSONAR(S7503) — stub matches embed_query async signature + text: str, *, model: Optional[str] = None, + ) -> List[float]: + return _vec(text) + + monkeypatch.setattr( + "slayer.embeddings.service.embed_batch", stub_embed_batch, + ) + monkeypatch.setattr( + embedding_client, "embed_query", stub_embed_query, + ) + + from slayer.embeddings.service import EmbeddingService + emb_service = EmbeddingService(storage=storage) + persisted_models = [] + for m in _make_models(): + persisted = await storage.get_model( + m.name, data_source="warehouse", + ) + assert persisted is not None + persisted_models.append(persisted) + await emb_service.refresh_model_subtree(persisted) + await emb_service.refresh_datasource( + name="warehouse", models=persisted_models, + ) + for mem in await storage.list_memories(entities=None): + await emb_service.refresh_memory(mem) + + yield storage + + +@pytest_asyncio.fixture +async def service_with_embeddings( + storage_with_embeddings: YAMLStorage, +) -> SearchService: + return SearchService(storage=storage_with_embeddings) + + +@pytest.mark.asyncio +async def test_memories_invariant_under_max_entities_with_channel_3_active( + service_with_embeddings: SearchService, +) -> None: + base = await _ids( + service_with_embeddings, + question="amount paid refund revenue customer email", + datasource="warehouse", + max_memories=10, + max_example_queries=2, + max_entities=5, + ) + for max_entities in (0, 20, 50): + other = await _ids( + service_with_embeddings, + question="amount paid refund revenue customer email", + datasource="warehouse", + max_memories=10, + max_example_queries=2, + max_entities=max_entities, + ) + assert other["memories"] == base["memories"], ( + f"channel-3 active: memories changed when max_entities went " + f"5 -> {max_entities}" + ) + + +@pytest.mark.asyncio +async def test_entities_invariant_under_max_memories_with_channel_3_active( + service_with_embeddings: SearchService, +) -> None: + base = await _ids( + service_with_embeddings, + question="amount paid refund customer email warehouse", + datasource="warehouse", + max_memories=5, + max_example_queries=2, + max_entities=10, + ) + for max_memories in (0, 20, 50): + other = await _ids( + service_with_embeddings, + question="amount paid refund customer email warehouse", + datasource="warehouse", + max_memories=max_memories, + max_example_queries=2, + max_entities=10, + ) + assert other["entities"] == base["entities"], ( + f"channel-3 active: entities changed when max_memories went " + f"5 -> {max_memories}" + ) From 8acc5625ce212fee3444f0ae2d4fff070fa29710 Mon Sep 17 00:00:00 2001 From: Egor Kraev Date: Thu, 14 May 2026 15:36:55 +0200 Subject: [PATCH 2/3] DEV-1414: address codex review (perf, stale rows, deterministic test hash) - HIGH _backfill_memory_by_id was O(N^2) under the new per-kind full rankings (linear scan of all_memories per id, 3x). Build a ``{m.id: m}`` dict once in ``SearchService.search`` and pass it in; per-call backfill is now O(N). - MEDIUM channel 3 ranked every sidecar row, including stale ones (memories deleted from storage, hidden / removed entities still embedded). Those rows consumed cosine ranks and degraded live docs' RRF scores (still invariant under cap changes, but lossy). Filter rows against ``corpus.canonical_to_kind`` before partitioning so the embedding candidate set matches the live tantivy corpus. - LOW invariance-test stub used Python ``hash()`` for deterministic vectors, which is randomised per interpreter. Switch to ``hashlib.sha256`` so the rankings are reproducible across runs. Co-Authored-By: Claude Opus 4.7 (1M context) --- slayer/search/service.py | 36 ++++++++++++++++++++++++++------- tests/test_search_invariance.py | 16 +++++++++++---- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/slayer/search/service.py b/slayer/search/service.py index 5d3dd67..5a9b46c 100644 --- a/slayer/search/service.py +++ b/slayer/search/service.py @@ -146,15 +146,20 @@ def _dedup(items: List[str]) -> List[str]: def _backfill_memory_by_id( *, memory_by_id: dict, - all_memories: List["Memory"], + all_memories_by_id: dict[int, "Memory"], mem_ids, ) -> None: """For each id in ``mem_ids`` not already in ``memory_by_id``, look it - up in ``all_memories`` and insert it. Mutates ``memory_by_id``.""" + up in ``all_memories_by_id`` and insert it. Mutates ``memory_by_id``. + + Takes a precomputed id→Memory dict (not the raw list) so per-call + backfill stays O(N) instead of O(N²) when every channel returns the + full memory corpus (DEV-1414). + """ for mem_id in mem_ids: if mem_id in memory_by_id: continue - mem = next((m for m in all_memories if m.id == mem_id), None) + mem = all_memories_by_id.get(mem_id) if mem is not None: memory_by_id[mem_id] = mem @@ -431,20 +436,22 @@ async def search( warnings = _dedup(warnings + channel_3_warnings) # Backfill memory_by_id from every channel so RRF can resolve - # any memory hit downstream. + # any memory hit downstream. Build the id→Memory dict once so + # the three backfills stay O(N) overall (DEV-1414). + all_memories_by_id = {m.id: m for m in all_memories} _backfill_memory_by_id( memory_by_id=memory_by_id, - all_memories=all_memories, + all_memories_by_id=all_memories_by_id, mem_ids=channel_1_memory_ranking, ) _backfill_memory_by_id( memory_by_id=memory_by_id, - all_memories=all_memories, + all_memories_by_id=all_memories_by_id, mem_ids=index_hits_by_memory_id.keys(), ) _backfill_memory_by_id( memory_by_id=memory_by_id, - all_memories=all_memories, + all_memories_by_id=all_memories_by_id, mem_ids=channel_3_memory_ranking, ) @@ -662,6 +669,14 @@ async def _run_channel_3( * memory rows whose ``canonical_id`` appears in the supplied ``eligible_memory_canonicals`` set (already datasource-filtered upstream). + + DEV-1414: rows whose ``canonical_id`` is not in the live tantivy + corpus (stale memory ids, hidden / deleted entities) are dropped + before the matrix build. Otherwise stale rows would consume + cosine rank positions and degrade live docs' RRF scores — + invariant under cap changes (so the per-bucket contract still + holds) but surprising and lossy. The filter keeps the channel's + candidate set aligned with channel 2's tantivy corpus. """ if not question_active or corpus is None: return [], [], [] @@ -685,6 +700,13 @@ async def _run_channel_3( datasource=datasource, eligible_memory_canonicals=eligible_memory_canonicals or set(), ) + # Drop sidecar rows that don't correspond to anything in the + # live tantivy corpus (DEV-1414). Memory rows are keyed + # ``memory:`` in storage and as the corpus's + # ``canonical_to_kind`` key; entity rows share the canonical + # string directly. Both shapes match by single dict lookup. + live_canonicals = corpus.canonical_to_kind.keys() + rows = [r for r in rows if r.canonical_id in live_canonicals] if not rows: return [], [], [ f"embedding channel skipped: no embedding rows for model " diff --git a/tests/test_search_invariance.py b/tests/test_search_invariance.py index 62c8b06..521206a 100644 --- a/tests/test_search_invariance.py +++ b/tests/test_search_invariance.py @@ -17,6 +17,7 @@ from __future__ import annotations +import hashlib import tempfile from typing import AsyncIterator, List, Optional @@ -477,11 +478,18 @@ async def storage_with_embeddings( monkeypatch.setattr(embedding_client, "is_available", lambda: True) # Deterministic embeddings: hash the rendered text into a tiny - # vector so ranks vary across docs but are reproducible. + # vector so ranks vary across docs but are reproducible across + # interpreter runs (Python's built-in ``hash`` is randomised + # per process, so use sha256 here). def _vec(text: str) -> List[float]: - return [ - (hash((text, i)) & 0xFFFF) / 65535.0 for i in range(8) - ] + out: List[float] = [] + for i in range(8): + digest = hashlib.sha256( + f"{text}|{i}".encode("utf-8"), + ).digest() + # First two bytes give a stable 16-bit unsigned int. + out.append(((digest[0] << 8) | digest[1]) / 65535.0) + return out async def stub_embed_batch( # NOSONAR(S7503) — stub matches embed_batch async signature texts: List[str], *, model: Optional[str] = None, From ec7223a79c0045eba1ff8ba31052a47ba6867463 Mon Sep 17 00:00:00 2001 From: Egor Kraev Date: Thu, 14 May 2026 15:56:47 +0200 Subject: [PATCH 3/3] DEV-1414: extract helpers to clear Sonar cognitive-complexity gate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three S3776 issues opened by the DEV-1414 changes (cognitive complexity 17, 19, 18 vs the 15 threshold). All reduce to "the function grew a few branches; extract a helper": - slayer/search/index.py: factor the kind-filter / exclude-kind boolean wrap-up into _apply_kind_filter so search_index reads as parse → wrap → search. - slayer/search/service.py: factor the per-kind cosine ranking into a module-level _rank_embedding_kind so _run_channel_3 calls it twice instead of inlining two near-identical matrix+top_k blocks. - tests/test_search_invariance.py: factor the topic→entities decision tree out of _seed_invariance_corpus into _entities_for_topic (CodeRabbit nitpicks #1 and #2 — same suggestion via both review batches). Co-Authored-By: Claude Opus 4.7 (1M context) --- slayer/search/index.py | 49 ++++++++++++++-------- slayer/search/service.py | 73 ++++++++++++++++++++------------- tests/test_search_invariance.py | 48 +++++++++++----------- 3 files changed, 101 insertions(+), 69 deletions(-) diff --git a/slayer/search/index.py b/slayer/search/index.py index 88cbf57..489475d 100644 --- a/slayer/search/index.py +++ b/slayer/search/index.py @@ -242,6 +242,31 @@ def build_in_memory_corpus( # --------------------------------------------------------------------------- +def _apply_kind_filter( + *, + query: "tantivy.Query", + schema: "tantivy.Schema", + kind_filter: Optional[str], + exclude_kind: Optional[str], +) -> "tantivy.Query": + """Wrap ``query`` in a boolean query that ``Must`` includes (or + ``MustNot`` excludes) docs whose ``kind`` field exactly equals the + supplied value. Returns ``query`` unchanged when neither argument + is set. The caller has already validated mutual exclusivity.""" + if kind_filter is None and exclude_kind is None: + return query + target = kind_filter if kind_filter is not None else exclude_kind + occur = ( + tantivy.Occur.Must if kind_filter is not None + else tantivy.Occur.MustNot + ) + kind_term = tantivy.Query.term_query(schema, "kind", target) + return tantivy.Query.boolean_query([ + (tantivy.Occur.Must, query), + (occur, kind_term), + ]) + + def search_index( *, index: tantivy.Index, @@ -285,24 +310,12 @@ def search_index( query = index.parse_query(question, fields) except (ValueError, RuntimeError): return [] - if kind_filter is not None or exclude_kind is not None: - schema = index.schema - if kind_filter is not None: - kind_term = tantivy.Query.term_query( - schema, "kind", kind_filter, - ) - query = tantivy.Query.boolean_query([ - (tantivy.Occur.Must, query), - (tantivy.Occur.Must, kind_term), - ]) - else: - kind_term = tantivy.Query.term_query( - schema, "kind", exclude_kind, - ) - query = tantivy.Query.boolean_query([ - (tantivy.Occur.Must, query), - (tantivy.Occur.MustNot, kind_term), - ]) + query = _apply_kind_filter( + query=query, + schema=index.schema, + kind_filter=kind_filter, + exclude_kind=exclude_kind, + ) searcher = index.searcher() raw_hits = searcher.search(query, limit).hits out: List[IndexHit] = [] diff --git a/slayer/search/service.py b/slayer/search/service.py index 5a9b46c..e43d8ba 100644 --- a/slayer/search/service.py +++ b/slayer/search/service.py @@ -301,6 +301,33 @@ def _memory_id_from_canonical(canonical_id: str) -> Optional[int]: return None +def _rank_embedding_kind( + *, + rows: List["Embedding"], + normalised_query, + np, + normalise_matrix, + top_k_cosine, +) -> List[str]: + """Rank one kind of embedding rows by cosine similarity to the + pre-normalised query vector. Returns the rows' ``canonical_id`` + strings in descending similarity order. Empty input → empty list. + + Pulls the per-kind matrix build + cosine call out of + ``SearchService._run_channel_3`` so each kind's ranking is a single + line in the caller (DEV-1414 — keeps channel 3 below the + cognitive-complexity gate).""" + if not rows: + return [] + matrix = np.array([r.embedding for r in rows], dtype=np.float32) + pairs = top_k_cosine( + query=normalised_query, + matrix=normalise_matrix(matrix), + k=len(rows), + ) + return [rows[idx].canonical_id for idx, _score in pairs] + + def _fuse_entity_hits( *, rankings: List[List[str]], @@ -742,35 +769,25 @@ async def _run_channel_3( memory_rows = [r for r in rows if r.entity_kind == "memory"] entity_rows = [r for r in rows if r.entity_kind != "memory"] normalised_query = normalise(query_vec) - + ranked_memory_canonicals = _rank_embedding_kind( + rows=memory_rows, + normalised_query=normalised_query, + np=np, + normalise_matrix=normalise_matrix, + top_k_cosine=top_k_cosine, + ) memory_ranking: List[int] = [] - if memory_rows: - memory_matrix = np.array( - [r.embedding for r in memory_rows], dtype=np.float32, - ) - for idx, _score in top_k_cosine( - query=normalised_query, - matrix=normalise_matrix(memory_matrix), - k=len(memory_rows), - ): - memory_id = _memory_id_from_canonical( - memory_rows[idx].canonical_id, - ) - if memory_id is not None: - memory_ranking.append(memory_id) - - entity_ranking: List[str] = [] - if entity_rows: - entity_matrix = np.array( - [r.embedding for r in entity_rows], dtype=np.float32, - ) - for idx, _score in top_k_cosine( - query=normalised_query, - matrix=normalise_matrix(entity_matrix), - k=len(entity_rows), - ): - entity_ranking.append(entity_rows[idx].canonical_id) - + for canonical in ranked_memory_canonicals: + memory_id = _memory_id_from_canonical(canonical) + if memory_id is not None: + memory_ranking.append(memory_id) + entity_ranking = _rank_embedding_kind( + rows=entity_rows, + normalised_query=normalised_query, + np=np, + normalise_matrix=normalise_matrix, + top_k_cosine=top_k_cosine, + ) return memory_ranking, entity_ranking, [] async def _collect_index_corpus( diff --git a/tests/test_search_invariance.py b/tests/test_search_invariance.py index 521206a..ecd0512 100644 --- a/tests/test_search_invariance.py +++ b/tests/test_search_invariance.py @@ -151,6 +151,29 @@ def _make_models() -> List[SlayerModel]: ] +def _entities_for_topic(topic: str) -> List[str]: + """Pick canonical entity tags for a learning-topic string. Pulled + out of ``_seed_invariance_corpus`` so each branch stays separate + from the seeding loop's control flow.""" + if "amount_paid" in topic or "paid" in topic or "revenue" in topic: + return ["warehouse.orders.amount_paid"] + if "email" in topic or "anonymous" in topic: + return ["warehouse.customers.email"] + if "ship" in topic or "warehouse" in topic: + return ["warehouse.warehouses"] + if "customer" in topic and "tier" in topic: + return ["warehouse.customers.customer_tier"] + if "customer" in topic: + return ["warehouse.customers"] + if "status" in topic: + return ["warehouse.orders.status"] + if "discount" in topic: + return ["warehouse.orders.discount_code"] + if "checkout" in topic or "fraud" in topic: + return ["warehouse.orders"] + return ["warehouse"] + + async def _seed_invariance_corpus(storage: StorageBackend) -> None: """Seed a corpus large enough to exercise the bottom-cliff cases that used to leak through the shared over_fetch budget.""" @@ -160,32 +183,11 @@ async def _seed_invariance_corpus(storage: StorageBackend) -> None: for model in _make_models(): await storage.save_model(model) - # 20 learning-only memories. + # 20+ learning-only memories tagged by topic. for i, topic in enumerate(_LEARNING_TOPICS): - # Spread entity tags so different memories surface for different - # questions. - entities: List[str] - if "amount_paid" in topic or "paid" in topic or "revenue" in topic: - entities = ["warehouse.orders.amount_paid"] - elif "email" in topic or "anonymous" in topic: - entities = ["warehouse.customers.email"] - elif "ship" in topic or "warehouse" in topic: - entities = ["warehouse.warehouses"] - elif "customer" in topic and "tier" in topic: - entities = ["warehouse.customers.customer_tier"] - elif "customer" in topic: - entities = ["warehouse.customers"] - elif "status" in topic: - entities = ["warehouse.orders.status"] - elif "discount" in topic: - entities = ["warehouse.orders.discount_code"] - elif "checkout" in topic or "fraud" in topic: - entities = ["warehouse.orders"] - else: - entities = ["warehouse"] await storage.save_memory( learning=f"KB{i:02d}: {topic}.", - entities=entities, + entities=_entities_for_topic(topic), ) # 8 query-bearing memories — drive the example_queries bucket.