Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/recallforge/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class HybridResult:
memory_role: str = "root"
memory_root_path: Optional[str] = None
memory_hit_count: int = 1
tags: Optional[List[str]] = None
audit: Optional[SearchAudit] = None # Per-result audit trail


Expand Down Expand Up @@ -494,6 +495,7 @@ def _vector_results_to_hybrid(self, results: List[SearchResult]) -> List[HybridR
memory_id=getattr(result, "memory_id", None),
memory_role=getattr(result, "memory_role", "root"),
memory_root_path=getattr(result, "memory_root_path", None),
tags=getattr(result, "tags", None),
))
return hybrid_results

Expand Down Expand Up @@ -1178,6 +1180,7 @@ def _normalize(values: Dict[str, float], neutral: float = 0.5) -> Dict[str, floa
memory_id=getattr(result, "memory_id", None),
memory_role=getattr(result, "memory_role", "root"),
memory_root_path=getattr(result, "memory_root_path", None),
tags=getattr(result, "tags", None),
audit=audit,
))

Expand Down Expand Up @@ -1210,6 +1213,20 @@ def _roll_up_memory_hits(self, results: List[HybridResult]) -> List[HybridResult
if not results:
return []

def _merge_tags(items: List[HybridResult]) -> Optional[List[str]]:
    """Union the tag lists of grouped hits into one deduped lowercase list.

    Tags are cleaned (str, strip, lower), order of first appearance is kept,
    and the merged list is capped at 8 entries. Returns None when no usable
    tags exist so empty groups serialize as null rather than [].
    """
    collected: List[str] = []
    observed: set[str] = set()
    for hit in items:
        for raw in getattr(hit, "tags", None) or []:
            normalized = str(raw or "").strip().lower()
            if normalized and normalized not in observed:
                observed.add(normalized)
                collected.append(normalized)
                if len(collected) >= 8:
                    # Hard cap — stop scanning once the budget is spent.
                    return collected
    return collected or None

grouped: Dict[str, List[HybridResult]] = {}
order: List[str] = []
for result in results:
Expand All @@ -1224,6 +1241,7 @@ def _roll_up_memory_hits(self, results: List[HybridResult]) -> List[HybridResult
group = sorted(grouped[key], key=lambda item: item.score, reverse=True)
representative = group[0]
representative.memory_hit_count = len(group)
representative.tags = _merge_tags(group)
memory_rollup_boost = 1.0
if len(group) > 1:
memory_rollup_boost += min(0.15, 0.03 * (len(group) - 1))
Expand Down Expand Up @@ -1461,6 +1479,7 @@ class BatchSearchResult:
score: float # Best score across queries
source: str # Comma-separated list of query indices that found this result
query_scores: Dict[int, float] # Map of query_index -> score
tags: Optional[List[str]] = None


def search_batch(
Expand Down Expand Up @@ -1593,6 +1612,7 @@ def run_single_query(q: BatchQuery) -> List[tuple]:
score=data['rrf_score'],
source=','.join(str(i) for i in sorted(data['query_indices'])),
query_scores=data['query_scores'],
tags=getattr(result, "tags", None),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Merge tags across duplicate batch hits

search_batch deduplicates on filepath but keeps a single stored result object, and the new tags field is copied only from that one object. When the same filepath is returned by multiple queries with different tag sets (for example, hybrid mode with memory rollup vs. fts/vec mode), later query tags are discarded, and because futures complete asynchronously this can make returned tags nondeterministic across runs. The merged entry should combine tags from all contributing hits instead of taking only the first-seen result’s tags.

Useful? React with 👍 / 👎.

))

final_results.sort(key=lambda x: x.score, reverse=True)
Expand Down
5 changes: 5 additions & 0 deletions src/recallforge/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,7 @@ async def _handle_search(arguments: dict, backend, storage) -> list[TextContent]
"memory_role": getattr(r, "memory_role", "root"),
"memory_root_path": getattr(r, "memory_root_path", None),
"memory_hit_count": getattr(r, "memory_hit_count", 1),
"tags": getattr(r, "tags", None),
}
for r in results
],
Expand Down Expand Up @@ -1007,6 +1008,7 @@ async def _handle_explain_results(arguments: dict, backend, storage) -> list[Tex
"memory_role": getattr(r, "memory_role", "root"),
"memory_root_path": getattr(r, "memory_root_path", None),
"memory_hit_count": getattr(r, "memory_hit_count", 1),
"tags": getattr(r, "tags", None),
}

if r.audit:
Expand Down Expand Up @@ -1098,6 +1100,7 @@ async def _handle_search_fts(arguments: dict, storage) -> list[TextContent]:
"session_id": r.session_id,
"project_id": r.project_id,
"profile": r.profile,
"tags": getattr(r, "tags", None),
}
for r in results
],
Expand Down Expand Up @@ -1168,6 +1171,7 @@ async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextCont
"session_id": r.session_id,
"project_id": r.project_id,
"profile": r.profile,
"tags": getattr(r, "tags", None),
}
for r in results
],
Expand Down Expand Up @@ -1253,6 +1257,7 @@ async def _handle_search_batch(arguments: dict, backend, storage) -> list[TextCo
"session_id": getattr(r, "session_id", None),
"project_id": getattr(r, "project_id", None),
"profile": getattr(r, "profile", None),
"tags": getattr(r, "tags", None),
}
for r in results
],
Expand Down
1 change: 1 addition & 0 deletions src/recallforge/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SearchResult:
memory_id: Optional[str] = None
memory_role: str = "root"
memory_root_path: Optional[str] = None
tags: Optional[List[str]] = None


@dataclass
Expand Down
117 changes: 117 additions & 0 deletions src/recallforge/storage/indexing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import fnmatch
import hashlib
import json
import logging
import os
import re
import shutil
import subprocess
import time
Expand Down Expand Up @@ -50,6 +52,13 @@ def _resolve_captioner(self, embed_func, method_name: str):
return candidate
return None

def _select_generation_backend(self, *embed_funcs):
"""Pick the first backend/function that exposes generate_text()."""
for embed_func in embed_funcs:
if self._resolve_captioner(embed_func, "generate_text"):
return embed_func
return None

def _describe_image(self, embed_func, image_path: str, enabled: bool) -> str:
if not enabled:
return ""
Expand Down Expand Up @@ -90,6 +99,98 @@ def _describe_video(self, embed_image_func, embed_video_func, video_path: str, f

return ""

def _normalize_media_tags(self, raw_tags: List[str], *, max_tags: int = 8) -> List[str]:
"""Normalize generated tag strings into a compact canonical tag list."""
normalized: List[str] = []
seen: set[str] = set()
stop_tags = {"image", "images", "video", "videos", "photo", "picture", "frame", "scene", "clip"}

for raw in raw_tags:
tag = re.sub(r"\s+", " ", str(raw or "").strip().lower())
tag = tag.strip("\"'` ")
tag = re.sub(r"^\s*(?:[-*•]\s*|\d+[\.\)]\s*)", "", tag)
tag = re.sub(r"^[#\s]+", "", tag)
tag = tag.replace("_", " ").strip()
tag = re.sub(r"[;:,.]+$", "", tag).strip()
if not tag or tag in stop_tags:
continue
if len(tag) > 48:
truncated = tag[:48].rsplit(" ", 1)[0].strip()
tag = truncated or tag[:48].strip()
if not tag or tag in seen:
continue
seen.add(tag)
normalized.append(tag)
if len(normalized) >= max_tags:
break

return normalized

def _parse_generated_media_tags(self, raw: str) -> List[str]:
"""Parse tag generation output from JSON, newline, or comma-separated text."""
text = str(raw or "").strip()
if not text:
return []

candidates: List[str] = []
if text.startswith("[") and text.endswith("]"):
try:
payload = json.loads(text)
if isinstance(payload, list):
candidates.extend(str(item) for item in payload)
except json.JSONDecodeError:
pass
elif text.startswith("{") and text.endswith("}"):
try:
payload = json.loads(text)
if isinstance(payload, dict) and isinstance(payload.get("tags"), list):
candidates.extend(str(item) for item in payload["tags"])
except json.JSONDecodeError:
pass

if not candidates:
for line in (line.strip() for line in text.splitlines() if line.strip()):
lowered = line.lower()
if lowered.startswith("tags:"):
line = line.split(":", 1)[1]
if "," in line:
candidates.extend(part.strip() for part in line.split(",") if part.strip())
else:
candidates.append(line)
Comment on lines +152 to +159
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Strip code fences before splitting generated tags

If the model returns fenced JSON (for example ```json ... ```), the JSON parse path is skipped and the fallback line splitter treats fence tokens and bracket fragments as tags (e.g. json, ["cat"). This produces malformed tags that get persisted and later surfaced in search/memory APIs, degrading retrieval quality for media memories. The parser should remove markdown fences (or extract the JSON payload) before the line/comma fallback.

Useful? React with 👍 / 👎.


return self._normalize_media_tags(candidates)

def _generate_media_tags(self, embed_func, source_text: str, media_kind: str) -> List[str]:
"""Generate a normalized tag set using the lightweight text generator."""
source = re.sub(r"\s+", " ", str(source_text or "").strip())
if not source:
return []

generator = self._resolve_captioner(embed_func, "generate_text")
if not generator:
return []

prompt = (
f"Generate 3 to 8 retrieval-friendly tags for this {media_kind} memory.\n"
"Rules:\n"
"- Return only a JSON array of strings\n"
"- Use lowercase short noun phrases\n"
"- Avoid duplicates\n"
"- Avoid speculation or uncertain details\n"
"- No full sentences\n\n"
f"Description:\n{source[:1200]}"
)
try:
raw = generator(prompt, max_tokens=96) or ""
except Exception as exc:
logger.warning("index_%s: tag generation failed: %s", media_kind, exc)
return []

tags = self._parse_generated_media_tags(raw)
if not tags:
logger.debug("index_%s: tag generation returned no usable tags", media_kind)
return tags

def index_document(
self,
path: str,
Expand Down Expand Up @@ -940,6 +1041,7 @@ def index_image(
caption_media: bool = True,
memory_role: str = "root",
memory_root_path: Optional[str] = None,
inherited_tags: Optional[List[str]] = None,
) -> str:
"""
Index an image file.
Expand Down Expand Up @@ -1008,6 +1110,11 @@ def index_image(

vector = embed_func(actual_path)
image_caption = self._describe_image(embed_func, actual_path, enabled=caption_media)
image_tags = (
self._generate_media_tags(embed_func, image_caption, "image")
if caption_media and memory_role == "root"
else list(inherited_tags or [])
)
self._backend.insert_embedding(
content_hash=content_hash,
seq=0,
Expand All @@ -1025,6 +1132,7 @@ def index_image(
profile=profile,
memory_role=memory_role,
memory_root_path=memory_root_path,
tags=image_tags or None,
)

# Schedule debounced FTS rebuild
Expand Down Expand Up @@ -1101,6 +1209,12 @@ def index_video(
)
parts = [part for part in (video_caption, transcript_summary) if part]
video_body = "\n\n".join(parts)[:4000]
video_tag_backend = self._select_generation_backend(embed_video_func, embed_image_func)
video_tags = (
self._generate_media_tags(video_tag_backend, video_body, "video")
if caption_media
Comment on lines +1213 to +1215
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Fall back to image generator when video generator is missing

Video tag generation currently passes embed_video_func or embed_image_func as a single candidate. When embed_video_func is present but does not implement generate_text (or is a plain callable wrapper), _generate_media_tags returns no tags and never tries embed_image_func, even if the image backend can generate tags. This silently drops derived tags for video roots in mixed-backend setups.

Useful? React with 👍 / 👎.

else []
)

try:
modified_at = int(os.path.getmtime(actual_path) * 1000)
Expand Down Expand Up @@ -1147,6 +1261,7 @@ def index_video(
profile=profile,
memory_role="root",
memory_root_path=logical_path,
tags=video_tags or None,
)
indexed_video_embeddings = 1
except Exception as e:
Expand Down Expand Up @@ -1174,6 +1289,7 @@ def index_video(
caption_media=caption_media,
memory_role="child",
memory_root_path=logical_path,
inherited_tags=video_tags or None,
)
indexed_frames += 1

Expand All @@ -1190,6 +1306,7 @@ def index_video(
profile=profile,
memory_role="child",
memory_root_path=logical_path,
tags=video_tags or None,
)
indexed_transcripts += 1

Expand Down
Loading
Loading