diff --git a/src/surreal_memory/core/fiber.py b/src/surreal_memory/core/fiber.py index d1a68c73..f3d337b9 100755 --- a/src/surreal_memory/core/fiber.py +++ b/src/surreal_memory/core/fiber.py @@ -7,6 +7,7 @@ from typing import Any from uuid import uuid4 +from surreal_memory.utils.tag_normalizer import normalize_tags_lower from surreal_memory.utils.timeutils import utcnow @@ -125,6 +126,10 @@ def create( if tags is not None and not auto_tags and not agent_tags: effective_agent = tags + # Normalize case at ingestion boundary so storage is always lowercase + effective_auto = normalize_tags_lower(effective_auto) + effective_agent = normalize_tags_lower(effective_agent) + return cls( id=fiber_id or str(uuid4()), neuron_ids=neuron_ids, diff --git a/src/surreal_memory/engine/brain_transplant.py b/src/surreal_memory/engine/brain_transplant.py index bf180f83..0a67e8a0 100755 --- a/src/surreal_memory/engine/brain_transplant.py +++ b/src/surreal_memory/engine/brain_transplant.py @@ -79,9 +79,14 @@ def _fiber_matches_tags( fiber: dict[str, Any], required_tags: frozenset[str], ) -> bool: - """Return True if the fiber carries ANY of the required tags.""" - fiber_tags = set(fiber.get("tags", [])) - return bool(fiber_tags & required_tags) + """Return True if the fiber carries ANY of the required tags. + + Both the fiber's tags and the query tags are lowercased before comparison + so that "KB" matches a fiber tagged "kb" regardless of original casing. 
+ """ + fiber_tags = {t.lower() for t in fiber.get("tags", [])} + normalized_required = frozenset(t.lower() for t in required_tags) + return bool(fiber_tags & normalized_required) def _fiber_matches_salience( diff --git a/src/surreal_memory/mcp/tool_handler_utils.py b/src/surreal_memory/mcp/tool_handler_utils.py index 20aa2f8e..924d68fb 100755 --- a/src/surreal_memory/mcp/tool_handler_utils.py +++ b/src/surreal_memory/mcp/tool_handler_utils.py @@ -22,13 +22,18 @@ def _parse_tags(args: dict[str, Any], *, max_items: int = _MAX_RECALL_TAGS) -> set[str] | None: """Parse and validate tags from MCP tool arguments. + Tag strings are lowercased so that callers querying "KB" match fibers + stored as "kb" and vice-versa (case-insensitive tag matching). + Returns a set of valid tag strings, or None if no valid tags provided. """ + from surreal_memory.utils.tag_normalizer import normalize_tags_lower + raw_tags = args.get("tags") if not raw_tags or not isinstance(raw_tags, list): return None tags = {t for t in raw_tags[:max_items] if isinstance(t, str) and 0 < len(t) <= _MAX_TAG_LENGTH} - return tags or None + return normalize_tags_lower(tags) or None def _require_brain_id(storage: NeuralStorage) -> str: diff --git a/src/surreal_memory/storage/surrealdb/store.py b/src/surreal_memory/storage/surrealdb/store.py index 4473e7fd..83092443 100755 --- a/src/surreal_memory/storage/surrealdb/store.py +++ b/src/surreal_memory/storage/surrealdb/store.py @@ -900,7 +900,9 @@ async def find_fibers( # Post-filter for complex conditions if tags: - fibers = [f for f in fibers if tags.issubset(f.tags)] + # Normalize query tags to lowercase so "KB" matches fibers stored as "kb" + normalized_tags = {t.lower() for t in tags} + fibers = [f for f in fibers if normalized_tags.issubset(f.tags)] if time_overlaps: start, end = time_overlaps # Normalize to naive UTC for comparison diff --git a/src/surreal_memory/utils/tag_normalizer.py b/src/surreal_memory/utils/tag_normalizer.py index 0d4e76c1..c981cb27 
def normalize_tags_lower(tags: set[str]) -> set[str]:
    """Lowercase every tag in *tags* and return the results as a new set.

    Only letter case is changed; no synonym mapping is performed here (use
    :class:`TagNormalizer` for that). The input set is left unmodified.

    Args:
        tags: Tag strings to normalize.

    Returns:
        A fresh set holding ``tag.lower()`` for each input tag. Tags that
        differ only by case collapse into a single entry.
    """
    lowered: set[str] = set()
    for tag in tags:
        lowered.add(tag.lower())
    return lowered
+""" + +from __future__ import annotations + +import sys +from unittest.mock import AsyncMock, MagicMock + +import pytest + +# Stub the optional surrealdb dependency so store.py can be imported in CI. +if "surrealdb" not in sys.modules: + _fake_surrealdb = MagicMock() + sys.modules["surrealdb"] = _fake_surrealdb + sys.modules["surrealdb.errors"] = MagicMock() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_fiber_row(fiber_id: str, tags: list[str]) -> dict: + """Build a minimal SurrealDB-style fiber row dict.""" + rid = MagicMock() + rid.table_name = "fiber" + rid.id = fiber_id + return { + "id": rid, + "neuron_ids": ["n1"], + "synapse_ids": [], + "anchor_neuron_id": "n1", + "pathway": ["n1"], + "conductivity": 1.0, + "coherence": 0.0, + "salience": 0.5, + "frequency": 0, + "auto_tags": tags, + "agent_tags": [], + "metadata": {}, + "compression_tier": 0, + "pinned": False, + } + + +# --------------------------------------------------------------------------- +# Boundary 1: Fiber.create() normalizes at ingestion +# --------------------------------------------------------------------------- + + +class TestFiberCreateNormalizesTags: + """Tags passed to Fiber.create() must be stored lowercased.""" + + def test_uppercase_agent_tags_lowercased(self) -> None: + from surreal_memory.core.fiber import Fiber + + f = Fiber.create( + neuron_ids={"n1"}, + synapse_ids=set(), + anchor_neuron_id="n1", + tags={"KB", "Python", "BACKEND"}, + ) + assert f.agent_tags == {"kb", "python", "backend"} + + def test_uppercase_auto_tags_lowercased(self) -> None: + from surreal_memory.core.fiber import Fiber + + f = Fiber.create( + neuron_ids={"n1"}, + synapse_ids=set(), + anchor_neuron_id="n1", + auto_tags={"Frontend", "API"}, + agent_tags={"DOCS"}, + ) + assert f.auto_tags == {"frontend", "api"} + assert f.agent_tags == {"docs"} + + def 
class TestParseTagsNormalizes:
    """Every tag string returned by _parse_tags must come back lowercased."""

    def test_uppercase_query_tags_lowercased(self) -> None:
        """Upper- and title-case query tags are returned in lowercase."""
        from surreal_memory.mcp.tool_handler_utils import _parse_tags

        args = {"tags": ["KB", "Python", "BACKEND"]}
        expected = {"kb", "python", "backend"}
        assert _parse_tags(args) == expected

    def test_mixed_case_lowercased(self) -> None:
        """Camel-case, all-caps and lowercase inputs all end up lowercase."""
        from surreal_memory.mcp.tool_handler_utils import _parse_tags

        args = {"tags": ["CamelCase", "ALLCAPS", "lower"]}
        expected = {"camelcase", "allcaps", "lower"}
        assert _parse_tags(args) == expected

    def test_empty_list_returns_none(self) -> None:
        """An empty tag list yields None."""
        from surreal_memory.mcp.tool_handler_utils import _parse_tags

        parsed = _parse_tags({"tags": []})
        assert parsed is None

    def test_no_tags_key_returns_none(self) -> None:
        """Arguments without a "tags" key yield None."""
        from surreal_memory.mcp.tool_handler_utils import _parse_tags

        parsed = _parse_tags({})
        assert parsed is None
tags. + + The SurrealDB connection is stubbed — _query returns pre-built rows so that + the post-filter logic (which lives in pure Python) is exercised directly. + """ + + @pytest.mark.asyncio + async def test_uppercase_query_matches_lowercase_stored_tag(self) -> None: + from surreal_memory.storage.surrealdb.store import SurrealDBStorage + + storage = SurrealDBStorage() + storage.set_brain("brain1") + + rows = [_make_fiber_row("fiber-a", ["kb", "python"])] + storage._query = AsyncMock(return_value=rows) # type: ignore[method-assign] + + # Query with "KB" — stored as "kb" — must still match + results = await storage.find_fibers(tags={"KB"}) + assert len(results) == 1 + assert "kb" in results[0].tags + + @pytest.mark.asyncio + async def test_lowercase_query_matches_lowercase_stored_tag(self) -> None: + from surreal_memory.storage.surrealdb.store import SurrealDBStorage + + storage = SurrealDBStorage() + storage.set_brain("brain1") + + rows = [_make_fiber_row("fiber-b", ["kb"])] + storage._query = AsyncMock(return_value=rows) # type: ignore[method-assign] + + results = await storage.find_fibers(tags={"kb"}) + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_mixed_case_query_matches_lowercase_stored_tag(self) -> None: + from surreal_memory.storage.surrealdb.store import SurrealDBStorage + + storage = SurrealDBStorage() + storage.set_brain("brain1") + + rows = [_make_fiber_row("fiber-c", ["kb"])] + storage._query = AsyncMock(return_value=rows) # type: ignore[method-assign] + + results = await storage.find_fibers(tags={"Kb"}) + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_tag_mismatch_returns_empty(self) -> None: + from surreal_memory.storage.surrealdb.store import SurrealDBStorage + + storage = SurrealDBStorage() + storage.set_brain("brain1") + + rows = [_make_fiber_row("fiber-d", ["python"])] + storage._query = AsyncMock(return_value=rows) # type: ignore[method-assign] + + results = await storage.find_fibers(tags={"KB"}) + assert 
results == [] + + @pytest.mark.asyncio + async def test_all_three_case_variants_match(self) -> None: + """Upper, lower, and mixed case queries all match the same stored fiber. + + This is the core case-insensitive assertion: querying "KB", "kb", or "Kb" + must all match a fiber that was stored with tag "kb". + """ + from surreal_memory.storage.surrealdb.store import SurrealDBStorage + + storage = SurrealDBStorage() + storage.set_brain("brain1") + + rows = [_make_fiber_row("fiber-e", ["kb"])] + + for query_tag in ("KB", "kb", "Kb"): + storage._query = AsyncMock(return_value=rows) # type: ignore[method-assign] + results = await storage.find_fibers(tags={query_tag}) + assert len(results) == 1, f"Expected match for query tag {query_tag!r}" + + +# --------------------------------------------------------------------------- +# Boundary 4: brain_transplant._fiber_matches_tags() +# --------------------------------------------------------------------------- + + +class TestFiberMatchesTagsCaseInsensitive: + """_fiber_matches_tags must be case-insensitive on both sides.""" + + def test_uppercase_required_matches_lowercase_fiber_tag(self) -> None: + from surreal_memory.engine.brain_transplant import _fiber_matches_tags + + fiber = {"tags": ["kb"]} + assert _fiber_matches_tags(fiber, frozenset({"KB"})) + + def test_lowercase_required_matches_uppercase_fiber_tag(self) -> None: + from surreal_memory.engine.brain_transplant import _fiber_matches_tags + + fiber = {"tags": ["KB"]} + assert _fiber_matches_tags(fiber, frozenset({"kb"})) + + def test_mixed_case_both_sides(self) -> None: + from surreal_memory.engine.brain_transplant import _fiber_matches_tags + + fiber = {"tags": ["Python"]} + assert _fiber_matches_tags(fiber, frozenset({"PYTHON"})) + + def test_no_match_returns_false(self) -> None: + from surreal_memory.engine.brain_transplant import _fiber_matches_tags + + fiber = {"tags": ["python"]} + assert not _fiber_matches_tags(fiber, frozenset({"KB"})) + + def 
class TestNormalizeTagsLower:
    """Contract checks for the normalize_tags_lower() helper."""

    def test_lowercases_all_tags(self) -> None:
        """Every tag in the input comes back lowercased."""
        from surreal_memory.utils.tag_normalizer import normalize_tags_lower

        raw = {"KB", "Python", "BACKEND"}
        expected = {"kb", "python", "backend"}
        assert normalize_tags_lower(raw) == expected

    def test_idempotent(self) -> None:
        """Normalizing an already-lowercase set twice leaves it unchanged."""
        from surreal_memory.utils.tag_normalizer import normalize_tags_lower

        tags = {"kb", "python"}
        once = normalize_tags_lower(tags)
        twice = normalize_tags_lower(once)
        assert twice == tags

    def test_empty_set(self) -> None:
        """An empty input yields an empty set."""
        from surreal_memory.utils.tag_normalizer import normalize_tags_lower

        expected: set[str] = set()
        assert normalize_tags_lower(set()) == expected