diff --git a/src/recallforge/search.py b/src/recallforge/search.py index 915ccff..a3b87ce 100644 --- a/src/recallforge/search.py +++ b/src/recallforge/search.py @@ -157,6 +157,7 @@ class HybridResult: memory_role: str = "root" memory_root_path: Optional[str] = None memory_hit_count: int = 1 + tags: Optional[List[str]] = None audit: Optional[SearchAudit] = None # Per-result audit trail @@ -494,6 +495,7 @@ def _vector_results_to_hybrid(self, results: List[SearchResult]) -> List[HybridR memory_id=getattr(result, "memory_id", None), memory_role=getattr(result, "memory_role", "root"), memory_root_path=getattr(result, "memory_root_path", None), + tags=getattr(result, "tags", None), )) return hybrid_results @@ -1178,6 +1180,7 @@ def _normalize(values: Dict[str, float], neutral: float = 0.5) -> Dict[str, floa memory_id=getattr(result, "memory_id", None), memory_role=getattr(result, "memory_role", "root"), memory_root_path=getattr(result, "memory_root_path", None), + tags=getattr(result, "tags", None), audit=audit, )) @@ -1210,6 +1213,20 @@ def _roll_up_memory_hits(self, results: List[HybridResult]) -> List[HybridResult if not results: return [] + def _merge_tags(items: List[HybridResult]) -> Optional[List[str]]: + merged: List[str] = [] + seen: set[str] = set() + for item in items: + for tag in getattr(item, "tags", None) or []: + cleaned = str(tag or "").strip().lower() + if not cleaned or cleaned in seen: + continue + seen.add(cleaned) + merged.append(cleaned) + if len(merged) >= 8: + return merged + return merged or None + grouped: Dict[str, List[HybridResult]] = {} order: List[str] = [] for result in results: @@ -1224,6 +1241,7 @@ def _roll_up_memory_hits(self, results: List[HybridResult]) -> List[HybridResult group = sorted(grouped[key], key=lambda item: item.score, reverse=True) representative = group[0] representative.memory_hit_count = len(group) + representative.tags = _merge_tags(group) memory_rollup_boost = 1.0 if len(group) > 1: memory_rollup_boost += min(0.15, 0.03 * (len(group) - 1)) @@ -1461,6 +1479,7 @@ class BatchSearchResult: score: float # Best score across queries source: str # Comma-separated list of query indices that found this result query_scores: Dict[int, float] # Map of query_index -> score + tags: Optional[List[str]] = None def search_batch( @@ -1593,6 +1612,7 @@ def run_single_query(q: BatchQuery) -> List[tuple]: score=data['rrf_score'], source=','.join(str(i) for i in sorted(data['query_indices'])), query_scores=data['query_scores'], + tags=getattr(result, "tags", None), )) final_results.sort(key=lambda x: x.score, reverse=True) diff --git a/src/recallforge/server.py b/src/recallforge/server.py index b4b215e..4dc4e59 100644 --- a/src/recallforge/server.py +++ b/src/recallforge/server.py @@ -937,6 +937,7 @@ async def _handle_search(arguments: dict, backend, storage) -> list[TextContent] "memory_role": getattr(r, "memory_role", "root"), "memory_root_path": getattr(r, "memory_root_path", None), "memory_hit_count": getattr(r, "memory_hit_count", 1), + "tags": getattr(r, "tags", None), } for r in results ], @@ -1007,6 +1008,7 @@ async def _handle_explain_results(arguments: dict, backend, storage) -> list[Tex "memory_role": getattr(r, "memory_role", "root"), "memory_root_path": getattr(r, "memory_root_path", None), "memory_hit_count": getattr(r, "memory_hit_count", 1), + "tags": getattr(r, "tags", None), } if r.audit: @@ -1098,6 +1100,7 @@ async def _handle_search_fts(arguments: dict, storage) -> list[TextContent]: "session_id": r.session_id, "project_id": r.project_id, "profile": r.profile, + "tags": getattr(r, "tags", None), } for r in results ], @@ -1168,6 +1171,7 @@ async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextCont "session_id": r.session_id, "project_id": r.project_id, "profile": r.profile, + "tags": getattr(r, "tags", None), } for r in results ], @@ -1253,6 +1257,7 @@ async def _handle_search_batch(arguments: dict, backend, storage) -> list[TextCo "session_id": getattr(r, "session_id", None), "project_id": getattr(r, "project_id", None), "profile": getattr(r, "profile", None), + "tags": getattr(r, "tags", None), } for r in results ], diff --git a/src/recallforge/storage/base.py b/src/recallforge/storage/base.py index ad62390..14e2d6b 100644 --- a/src/recallforge/storage/base.py +++ b/src/recallforge/storage/base.py @@ -39,6 +39,7 @@ class SearchResult: memory_id: Optional[str] = None memory_role: str = "root" memory_root_path: Optional[str] = None + tags: Optional[List[str]] = None @dataclass diff --git a/src/recallforge/storage/indexing_ops.py b/src/recallforge/storage/indexing_ops.py index 54447e0..1ecd753 100644 --- a/src/recallforge/storage/indexing_ops.py +++ b/src/recallforge/storage/indexing_ops.py @@ -2,8 +2,10 @@ import fnmatch import hashlib +import json import logging import os +import re import shutil import subprocess import time @@ -50,6 +52,13 @@ def _resolve_captioner(self, embed_func, method_name: str): return candidate return None + def _select_generation_backend(self, *embed_funcs): + """Pick the first backend/function that exposes generate_text().""" + for embed_func in embed_funcs: + if self._resolve_captioner(embed_func, "generate_text"): + return embed_func + return None + def _describe_image(self, embed_func, image_path: str, enabled: bool) -> str: if not enabled: return "" @@ -90,6 +99,98 @@ def _describe_video(self, embed_image_func, embed_video_func, video_path: str, f return "" + def _normalize_media_tags(self, raw_tags: List[str], *, max_tags: int = 8) -> List[str]: + """Normalize generated tag strings into a compact canonical tag list.""" + normalized: List[str] = [] + seen: set[str] = set() + stop_tags = {"image", "images", "video", "videos", "photo", "picture", "frame", "scene", "clip"} + + for raw in raw_tags: + tag = re.sub(r"\s+", " ", str(raw or "").strip().lower()) + tag = tag.strip("\"'` ") + tag = re.sub(r"^\s*(?:[-*•]\s*|\d+[\.\)]\s*)", "", tag) + tag = re.sub(r"^[#\s]+", "", tag) + tag = tag.replace("_", " ").strip() + tag = re.sub(r"[;:,.]+$", "", tag).strip() + if not tag or tag in stop_tags: + continue + if len(tag) > 48: + truncated = tag[:48].rsplit(" ", 1)[0].strip() + tag = truncated or tag[:48].strip() + if not tag or tag in seen: + continue + seen.add(tag) + normalized.append(tag) + if len(normalized) >= max_tags: + break + + return normalized + + def _parse_generated_media_tags(self, raw: str) -> List[str]: + """Parse tag generation output from JSON, newline, or comma-separated text.""" + text = str(raw or "").strip() + if not text: + return [] + + candidates: List[str] = [] + if text.startswith("[") and text.endswith("]"): + try: + payload = json.loads(text) + if isinstance(payload, list): + candidates.extend(str(item) for item in payload) + except json.JSONDecodeError: + pass + elif text.startswith("{") and text.endswith("}"): + try: + payload = json.loads(text) + if isinstance(payload, dict) and isinstance(payload.get("tags"), list): + candidates.extend(str(item) for item in payload["tags"]) + except json.JSONDecodeError: + pass + + if not candidates: + for line in (line.strip() for line in text.splitlines() if line.strip()): + lowered = line.lower() + if lowered.startswith("tags:"): + line = line.split(":", 1)[1] + if "," in line: + candidates.extend(part.strip() for part in line.split(",") if part.strip()) + else: + candidates.append(line) + + return self._normalize_media_tags(candidates) + + def _generate_media_tags(self, embed_func, source_text: str, media_kind: str) -> List[str]: + """Generate a normalized tag set using the lightweight text generator.""" + source = re.sub(r"\s+", " ", str(source_text or "").strip()) + if not source: + return [] + + generator = self._resolve_captioner(embed_func, "generate_text") + if not generator: + return [] + + prompt = ( + f"Generate 3 to 8 retrieval-friendly tags for this {media_kind} memory.\n" + "Rules:\n" + "- Return only a JSON array of strings\n" + "- Use lowercase short noun phrases\n" + "- Avoid duplicates\n" + "- Avoid speculation or uncertain details\n" + "- No full sentences\n\n" + f"Description:\n{source[:1200]}" + ) + try: + raw = generator(prompt, max_tokens=96) or "" + except Exception as exc: + logger.warning("index_%s: tag generation failed: %s", media_kind, exc) + return [] + + tags = self._parse_generated_media_tags(raw) + if not tags: + logger.debug("index_%s: tag generation returned no usable tags", media_kind) + return tags + def index_document( self, path: str, @@ -940,6 +1041,7 @@ def index_image( caption_media: bool = True, memory_role: str = "root", memory_root_path: Optional[str] = None, + inherited_tags: Optional[List[str]] = None, ) -> str: """ Index an image file. @@ -1008,6 +1110,11 @@ def index_image( vector = embed_func(actual_path) image_caption = self._describe_image(embed_func, actual_path, enabled=caption_media) + image_tags = ( + self._generate_media_tags(embed_func, image_caption, "image") + if caption_media and memory_role == "root" + else list(inherited_tags or []) + ) self._backend.insert_embedding( content_hash=content_hash, seq=0, @@ -1025,6 +1132,7 @@ def index_image( profile=profile, memory_role=memory_role, memory_root_path=memory_root_path, + tags=image_tags or None, ) # Schedule debounced FTS rebuild @@ -1101,6 +1209,12 @@ def index_video( ) parts = [part for part in (video_caption, transcript_summary) if part] video_body = "\n\n".join(parts)[:4000] + video_tag_backend = self._select_generation_backend(embed_video_func, embed_image_func) + video_tags = ( + self._generate_media_tags(video_tag_backend, video_body, "video") + if caption_media + else [] + ) try: modified_at = int(os.path.getmtime(actual_path) * 1000) @@ -1147,6 +1261,7 @@ def index_video( profile=profile, memory_role="root", memory_root_path=logical_path, + tags=video_tags or None, ) indexed_video_embeddings = 1 except Exception as e: @@ -1174,6 +1289,7 @@ def index_video( caption_media=caption_media, memory_role="child", memory_root_path=logical_path, + inherited_tags=video_tags or None, ) indexed_frames += 1 @@ -1190,6 +1306,7 @@ def index_video( profile=profile, memory_role="child", memory_root_path=logical_path, + tags=video_tags or None, ) indexed_transcripts += 1 diff --git a/src/recallforge/storage/lancedb_backend.py b/src/recallforge/storage/lancedb_backend.py index b1552ca..0d07bbc 100644 --- a/src/recallforge/storage/lancedb_backend.py +++ b/src/recallforge/storage/lancedb_backend.py @@ -7,6 +7,7 @@ import fnmatch import hashlib +import json import logging import math import os @@ -1120,7 +1121,7 @@ def _fetch_memory_summary_rows( rows = list( self._embeddings_table.search() .where(" AND ".join(embed_filters)) - .select(["file_path", "memory_root_path", "text_body", "pos", "seq", "memory_role"]) + .select(["file_path", "memory_root_path", "text_body", "pos", "seq", "memory_role", "tags"]) .limit(max(100, min(5000, len(unique_root_paths) * 40))) .to_list() ) @@ -1135,6 +1136,53 @@ def _fetch_memory_summary_rows( grouped.setdefault(key, []).append(row) return grouped + def _derive_memory_tags( + self, + snippet_rows: List[Dict[str, Any]], + ) -> Optional[List[str]]: + """Build a compact deduplicated tag set from stored embedding metadata.""" + if not snippet_rows: + return None + + ordered_rows = sorted( + snippet_rows, + key=lambda row: ( + 0 if (row.get("memory_role") or "child") == "root" else 1, + row.get("pos", 0) or 0, + row.get("seq", 0) or 0, + row.get("file_path", ""), + ), + ) + + tags: List[str] = [] + seen: set[str] = set() + for row in ordered_rows: + raw_tags = row.get("tags") + parsed: List[str] = [] + if isinstance(raw_tags, list): + parsed = [str(tag).strip().lower() for tag in raw_tags if str(tag).strip()] + elif isinstance(raw_tags, str) and raw_tags.strip(): + try: + payload = json.loads(raw_tags) + if isinstance(payload, list): + parsed = [str(tag).strip().lower() for tag in payload if str(tag).strip()] + except json.JSONDecodeError: + parsed = [ + part.strip().lower() + for part in raw_tags.split(",") + if part.strip() + ] + + for tag in parsed: + if tag in seen: + continue + seen.add(tag) + tags.append(tag) + if len(tags) >= 8: + return tags + + return tags or None + def _derive_memory_summary( self, root_row: Dict[str, Any], @@ -1330,6 +1378,9 @@ def list_memories( row, summary_rows_by_root.get(root_path, []), ), + "tags": self._derive_memory_tags( + summary_rows_by_root.get(root_path, []), + ), } ) return output @@ -1412,7 +1463,7 @@ def get_memory( snippet_rows = list( self._embeddings_table.search() .where(" AND ".join(embed_filters)) - .select(["file_path", "content_type", "text_body", "pos", "seq", "memory_role"]) + .select(["file_path", "content_type", "text_body", "pos", "seq", "memory_role", "tags"]) .limit(20) .to_list() ) @@ -1463,6 +1514,7 @@ def get_memory( "project_id": root_row.get("project_id"), "profile": root_row.get("profile"), "summary": summary, + "tags": self._derive_memory_tags(snippet_rows), "root_document": { "path": root_row.get("file_path"), "content_hash": root_row.get("content_hash"), diff --git a/src/recallforge/storage/search_ops.py b/src/recallforge/storage/search_ops.py index 66b0e87..91e5b4f 100644 --- a/src/recallforge/storage/search_ops.py +++ b/src/recallforge/storage/search_ops.py @@ -1,5 +1,6 @@ """Search operations service for LanceDB storage backend.""" +import json import math import re import time @@ -55,7 +56,7 @@ def _bm25_fallback( .select(["collection", "file_path", "content_hash", "content_type", "title", "text_body", "embedded_at", "modified_at", "user_id", "session_id", "project_id", "profile", "memory_id", "memory_role", - "memory_root_path", "expires_at"]) + "memory_root_path", "tags", "expires_at"]) .limit(row_limit) ) rows = builder.to_pandas() @@ -280,6 +281,24 @@ def _make_search_result(self, row: Dict[str, Any], score: float, source: str) -> # Fallback only when text_body is empty - lazy load for final output body = self._backend.get_content(content_hash) or "" + raw_tags = row.get("tags") + decoded_tags: Optional[List[str]] = None + if isinstance(raw_tags, list): + decoded_tags = [str(tag).strip().lower() for tag in raw_tags if str(tag).strip()] + elif isinstance(raw_tags, str) and raw_tags.strip(): + try: + payload = json.loads(raw_tags) + if isinstance(payload, list): + decoded_tags = [str(tag).strip().lower() for tag in payload if str(tag).strip()] + except json.JSONDecodeError: + decoded_tags = [ + part.strip().lower() + for part in raw_tags.split(",") + if part.strip() + ] + if decoded_tags == []: + decoded_tags = None + return SearchResult( filepath=f"recallforge://{collection}/{file_path}", display_path=f"{collection}/{file_path}", @@ -302,6 +321,7 @@ def _make_search_result(self, row: Dict[str, Any], score: float, source: str) -> memory_id=memory_id, memory_role=memory_role, memory_root_path=memory_root_path, + tags=decoded_tags, ) def _get_ttl_filter(self) -> str: diff --git a/src/recallforge/watch_folder.py b/src/recallforge/watch_folder.py index 19b4600..4594065 100644 --- a/src/recallforge/watch_folder.py +++ b/src/recallforge/watch_folder.py @@ -18,7 +18,7 @@ from dataclasses import asdict, dataclass, field from pathlib import Path from queue import Empty, Queue -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from .documents import is_document_file from .video import is_video_file @@ -158,15 +158,19 @@ def _match(pattern: str) -> bool: or self._is_document_file(path) ) - def _build_snapshot(self, config: WatchConfig) -> Dict[str, float]: + def _build_snapshot(self, config: WatchConfig) -> Dict[str, Tuple[int, int, int]]: root = Path(config.folder_path).expanduser().resolve() - snap: Dict[str, float] = {} + snap: Dict[str, Tuple[int, int, int]] = {} for p in self._candidate_files(root, config.recursive): if not self._should_process(p, config): continue try: rel = p.resolve().relative_to(root).as_posix() - snap[rel] = p.stat().st_mtime + stat = p.stat() + # Some CI filesystems do not reliably advance mtime between fast + # successive writes, so include size in the snapshot to catch + # quick content changes during watch-folder polling. + snap[rel] = (stat.st_mtime_ns, stat.st_ctime_ns, stat.st_size) except Exception: logger.warning("Failed to get stats for %s", p) continue @@ -183,10 +187,10 @@ def _scanner_loop(self, watch_id: str) -> None: current = self._build_snapshot(config) # created/modified - for rel, mtime in current.items(): + for rel, state in current.items(): if rel not in prev: queue.put({"path": str(root / rel), "type": "created", "timestamp": time.time()}) - elif mtime > prev[rel]: + elif state != prev[rel]: queue.put({"path": str(root / rel), "type": "modified", "timestamp": time.time()}) # deleted diff --git a/tests/test_config_tools.py b/tests/test_config_tools.py index 9a91f2d..e762c1c 100644 --- a/tests/test_config_tools.py +++ b/tests/test_config_tools.py @@ -557,6 +557,7 @@ async def test_search_file_path_routes_through_text_query(self): result_item.memory_role = "root" result_item.memory_root_path = None result_item.memory_hit_count = 1 + result_item.tags = ["memory query", "markdown"] with tempfile.NamedTemporaryFile("w", suffix=".md", delete=False) as tmp: tmp.write("memory query from file path") @@ -571,6 +572,7 @@ async def test_search_file_path_routes_through_text_query(self): data = json.loads(result[0].text) self.assertEqual(data["file_path"], file_path) + self.assertEqual(data["results"][0]["tags"], ["memory query", "markdown"]) mock_searcher.search.assert_called_once() self.assertIn("memory query from file path", mock_searcher.search.call_args[0][0]) finally: diff --git a/tests/test_search_pipeline.py b/tests/test_search_pipeline.py index 53aa5b7..fcde078 100644 --- a/tests/test_search_pipeline.py +++ b/tests/test_search_pipeline.py @@ -27,7 +27,7 @@ # --------------------------------------------------------------------------- def _make_search_result(filepath: str, score: float = 0.9, source: str = "fts", - content_type: str = "text") -> SearchResult: + content_type: str = "text", tags: Optional[List[str]] = None) -> SearchResult: return SearchResult( filepath=filepath, display_path=filepath, @@ -43,6 +43,7 @@ def _make_search_result(filepath: str, score: float = 0.9, source: str = "fts", content_type=content_type, chunk_pos=0, body=f"Content of {filepath}", + tags=tags, ) @@ -437,6 +438,25 @@ def test_memory_rollup_keeps_singletons_unboosted(self): self.assertAlmostEqual(result.audit.memory_rollup_boost, 1.0) self.assertAlmostEqual(result.score, result.audit.final_blended_score) + def test_memory_rollup_merges_tags_from_sibling_assets(self): + searcher = HybridSearcher(backend=StubBackend(), storage=StubStorage(), limit=12) + + root = _make_search_result("memory_root.png", 0.9, "vec", "image", tags=["diagram"]) + child = _make_search_result( + "memory_child.txt", + 0.8, + "fts", + "text", + tags=["meeting notes", "diagram"], + ) + root.memory_id = "memory-tags" + child.memory_id = "memory-tags" + + rolled = searcher._roll_up_memory_hits(searcher._vector_results_to_hybrid([root, child])) + + self.assertEqual(len(rolled), 1) + self.assertEqual(rolled[0].tags, ["diagram", "meeting notes"]) + class TestParallelSearchTaskCapture(unittest.TestCase): def test_parallel_search_captures_original_vector(self): diff --git a/tests/test_storage.py b/tests/test_storage.py index 9a553fc..ac48b13 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -7,6 +7,7 @@ """ import hashlib +import json import os import shutil import sys @@ -1212,6 +1213,14 @@ def describe_video(self, path: str, frame_paths=None) -> str: frame_count = len(frame_paths or []) return f"Technical explainer video with {frame_count} keyframes showing diagrams." + def generate_text(self, prompt: str, max_tokens: int = 60) -> str: + prompt_lower = prompt.lower() + if "image memory" in prompt_lower: + return '["neural network", "diagram", "hidden layers"]' + if "video memory" in prompt_lower: + return '["technical explainer", "architecture diagram", "presentation"]' + return "[]" + class FailingCaptioningEmbedder(CaptioningEmbedder): def caption_image(self, path: str) -> str: @@ -1255,6 +1264,10 @@ def test_index_image_stores_caption_in_text_body(self): rows = self.backend._embeddings_table.search().where("content_type = 'image'").to_list() self.assertEqual(len(rows), 1) self.assertIn("Neural network diagram", rows[0].get("text_body") or "") + self.assertEqual( + json.loads(rows[0].get("tags") or "[]"), + ["neural network", "diagram", "hidden layers"], + ) def test_index_image_caption_failure_keeps_embedding(self): embedder = FailingCaptioningEmbedder() @@ -1299,10 +1312,38 @@ def test_index_video_stores_video_caption_in_text_body(self): video_rows = self.backend._embeddings_table.search().where("content_type = 'video'").to_list() self.assertEqual(len(video_rows), 1) self.assertIn("Technical explainer video", video_rows[0].get("text_body") or "") + self.assertEqual( + json.loads(video_rows[0].get("tags") or "[]"), + ["technical explainer", "architecture diagram", "presentation"], + ) + + def test_memory_lookup_surfaces_media_tags(self): + embedder = CaptioningEmbedder() + self.backend.index_image( + path=self.image_path, + collection="test", + embed_func=embedder, + caption_media=True, + ) + + memories = self.backend.list_memories(collection="test", limit=10) + self.assertEqual(len(memories), 1) + self.assertEqual( + memories[0]["tags"], + ["neural network", "diagram", "hidden layers"], + ) + + memory = self.backend.get_memory(path=str(Path(self.image_path).expanduser().resolve()), collection="test") + self.assertIsNotNone(memory) + self.assertEqual( + memory["tags"], + ["neural network", "diagram", "hidden layers"], + ) def test_index_video_keeps_parent_memory_and_links_children(self): embedder = CaptioningEmbedder() logical_path = str(Path(self.video_path).expanduser().resolve()) + expected_tags = ["technical explainer", "architecture diagram", "presentation"] fake_artifacts = SimpleNamespace( frames=[ SimpleNamespace( @@ -1350,6 +1391,17 @@ def failing_video_embed(_path: str): self.assertEqual(row.get("memory_role"), "child") self.assertEqual(row.get("memory_root_path"), logical_path) + child_embedding_rows = self.backend._embeddings_table.search().where( + f"collection = 'test' AND file_path LIKE '{logical_path}::%'" + ).to_list() + self.assertGreaterEqual(len(child_embedding_rows), 2) + for row in child_embedding_rows: + self.assertEqual(json.loads(row.get("tags") or "[]"), expected_tags) + + memory = self.backend.get_memory(path=logical_path, collection="test") + self.assertIsNotNone(memory) + self.assertEqual(memory["tags"], expected_tags) + def test_index_document_file_creates_root_memory_and_links_sections(self): document_path = os.path.join(self.temp_dir, "report.pdf") logical_path = str(Path(document_path).expanduser().resolve())