From 17943ae9fdbea4f142108683d86f73c1ed6a80df Mon Sep 17 00:00:00 2001 From: MollyAI Date: Fri, 27 Mar 2026 22:40:10 -0400 Subject: [PATCH 1/7] Recover expansion benchmark work and add MLX safety guardrails --- benchmarks/cross_modal_ablation.py | 160 ++++- docs/RELEASE.md | 19 +- docs/mcp-tools.md | 3 +- src/recallforge/__init__.py | 4 + src/recallforge/backends/mlx_backend.py | 837 +++++++++++++++-------- src/recallforge/search.py | 41 +- tests/test_cross_modal_benchmark_defs.py | 48 ++ tests/test_mlx_reranker_prompt.py | 187 +++++ tests/test_query_expansion.py | 4 + tests/test_search_pipeline.py | 63 ++ 10 files changed, 1039 insertions(+), 327 deletions(-) diff --git a/benchmarks/cross_modal_ablation.py b/benchmarks/cross_modal_ablation.py index 01e6d97..2bc3217 100644 --- a/benchmarks/cross_modal_ablation.py +++ b/benchmarks/cross_modal_ablation.py @@ -1742,6 +1742,82 @@ def ingest_corpus(backend, storage, collection: str, corpus_dir: Path) -> int: ] +@dataclass(frozen=True) +class ExpansionProfile: + """Benchmark-time query expansion configuration.""" + + name: str + expand: bool + enable_media_query_probe: bool + allow_generate_text: bool + + +EXPANSION_PROFILES: Dict[str, ExpansionProfile] = { + # Text queries: no expansion. + # Media queries: pure vector search with no caption/transcript BM25 probe. + "off": ExpansionProfile( + name="off", + expand=False, + enable_media_query_probe=False, + allow_generate_text=False, + ), + # Text queries: no expansion. + # Media queries: keep caption/transcript BM25 probe, but do not add expansion branches. + "caption_only": ExpansionProfile( + name="caption_only", + expand=False, + enable_media_query_probe=True, + allow_generate_text=False, + ), + # Text/media probe expansion uses heuristic fallback rules only. + "heuristic": ExpansionProfile( + name="heuristic", + expand=True, + enable_media_query_probe=True, + allow_generate_text=False, + ), + # Text/media probe expansion uses the backend's generator when available. + "qwen": ExpansionProfile( + name="qwen", + expand=True, + enable_media_query_probe=True, + allow_generate_text=True, + ), +} + + +class _ExpansionBackendProxy: + """Optionally hide generated expansion so benchmarks can force heuristics.""" + + def __init__(self, backend, *, allow_generate_text: bool): + self._backend = backend + self._allow_generate_text = allow_generate_text + + def generate_text(self, prompt: str, max_tokens: int = 60) -> str: + if not self._allow_generate_text: + raise NotImplementedError( + "generate_text disabled for this benchmark expansion profile" + ) + generator = getattr(self._backend, "generate_text", None) + if not callable(generator): + raise NotImplementedError("Underlying backend does not support generate_text") + return generator(prompt, max_tokens=max_tokens) + + def __getattr__(self, name: str): + return getattr(self._backend, name) + + +def _resolve_expansion_profile(profile_name: str) -> ExpansionProfile: + """Return the named query-expansion benchmark profile.""" + try: + return EXPANSION_PROFILES[profile_name] + except KeyError as exc: + valid = ", ".join(sorted(EXPANSION_PROFILES)) + raise ValueError( + f"Unknown expansion profile: {profile_name}. Valid choices: {valid}" + ) from exc + + def _group_queries( category_filters: Optional[List[str]] = None, max_queries_per_category: Optional[int] = None, @@ -1781,10 +1857,17 @@ def _select_stages(stage_filters: Optional[List[str]] = None) -> List[Tuple[str, return selected -def _resolve_output_path(output_path: Optional[str]) -> str: +def _resolve_output_path(output_path: Optional[str], expansion_profile: str) -> str: """Return the benchmark output path, falling back to the default results file.""" - return output_path or str( - PROJECT_ROOT / "benchmarks" / "results" / "cross_modal_ablation_results.json" + if output_path: + return output_path + + suffix = "" if expansion_profile == "caption_only" else f"_{expansion_profile}" + return str( + PROJECT_ROOT + / "benchmarks" + / "results" + / f"cross_modal_ablation_results{suffix}.json" ) @@ -1793,6 +1876,7 @@ def _build_output_payload( all_results: Dict[str, Dict[str, StageResult]], stages: List[Tuple[str, str]], *, + expansion_profile: ExpansionProfile, indexed_items: int, run_status: str, interrupted: bool, @@ -1806,6 +1890,12 @@ def _build_output_payload( "benchmark": "cross_modal_ablation", "version": __version__, "generated_at": datetime.now(timezone.utc).isoformat(), + "configuration": { + "expansion_profile": expansion_profile.name, + "expand_enabled": expansion_profile.expand, + "media_query_probe_enabled": expansion_profile.enable_media_query_probe, + "generated_query_expansion_enabled": expansion_profile.allow_generate_text, + }, "run_status": run_status, "interrupted": interrupted, "progress": { @@ -1937,12 +2027,18 @@ def run_search( collection: str, stage_mode: str, limit: int = 10, + expansion_profile: str = "caption_only", ) -> Tuple[List[Dict], float]: """Run a single search query and return results + latency_ms.""" from recallforge.search import HybridSearcher t0 = time.perf_counter() result_content_type = _result_content_type_for_category(gt.category) + profile = _resolve_expansion_profile(expansion_profile) + search_backend = _ExpansionBackendProxy( + backend, + allow_generate_text=profile.allow_generate_text, + ) if gt.query_type == "image" and gt.image_query_path: image_path = str(CORPUS_DIR / gt.image_query_path) @@ -1957,16 +2053,18 @@ def run_search( ) elif stage_mode in ("rrf", "hybrid"): searcher = HybridSearcher( - backend=backend, + backend=search_backend, storage=storage, limit=limit, collection=collection, content_type=result_content_type, + expand=profile.expand, + enable_media_query_probe=profile.enable_media_query_probe, ) - old_mode = backend.get_mode() - backend.set_mode("embed" if stage_mode == "rrf" else "hybrid") + old_mode = search_backend.get_mode() + search_backend.set_mode("embed" if stage_mode == "rrf" else "hybrid") results = searcher.search_image(image_path) - backend.set_mode(old_mode) + search_backend.set_mode(old_mode) elif stage_mode == "bm25": results = [] else: @@ -1985,16 +2083,18 @@ def run_search( ) elif stage_mode in ("rrf", "hybrid"): searcher = HybridSearcher( - backend=backend, + backend=search_backend, storage=storage, limit=limit, collection=collection, content_type=result_content_type, + expand=profile.expand, + enable_media_query_probe=profile.enable_media_query_probe, ) - old_mode = backend.get_mode() - backend.set_mode("embed" if stage_mode == "rrf" else "hybrid") + old_mode = search_backend.get_mode() + search_backend.set_mode("embed" if stage_mode == "rrf" else "hybrid") results = searcher.search_video(video_path) - backend.set_mode(old_mode) + search_backend.set_mode(old_mode) elif stage_mode == "bm25": results = [] else: @@ -2022,29 +2122,33 @@ def run_search( elif stage_mode == "rrf": # RRF without reranker (embed mode) searcher = HybridSearcher( - backend=backend, + backend=search_backend, storage=storage, limit=limit, collection=collection, content_type=result_content_type, + expand=profile.expand, + enable_media_query_probe=profile.enable_media_query_probe, ) - old_mode = backend.get_mode() - backend.set_mode("embed") + old_mode = search_backend.get_mode() + search_backend.set_mode("embed") results = searcher.search(gt.query) - backend.set_mode(old_mode) + search_backend.set_mode(old_mode) elif stage_mode == "hybrid": # Full hybrid with reranker searcher = HybridSearcher( - backend=backend, + backend=search_backend, storage=storage, limit=limit, collection=collection, content_type=result_content_type, + expand=profile.expand, + enable_media_query_probe=profile.enable_media_query_probe, ) - old_mode = backend.get_mode() - backend.set_mode("hybrid") + old_mode = search_backend.get_mode() + search_backend.set_mode("hybrid") results = searcher.search(gt.query) - backend.set_mode(old_mode) + search_backend.set_mode(old_mode) else: raise ValueError(f"Unknown stage mode: {stage_mode}") @@ -2087,6 +2191,7 @@ def run_benchmark( category_filters: Optional[List[str]] = None, stage_filters: Optional[List[str]] = None, max_queries_per_category: Optional[int] = None, + expansion_profile: str = "caption_only", ) -> Dict[str, Any]: """Run the full cross-modal ablation benchmark.""" @@ -2097,7 +2202,8 @@ def run_benchmark( if not categories: raise ValueError("No benchmark queries selected") stages = _select_stages(stage_filters) - save_path = _resolve_output_path(output_path) + profile = _resolve_expansion_profile(expansion_profile) + save_path = _resolve_output_path(output_path, profile.name) indexed = 0 completed_stages: List[str] = [] current_stage_name: Optional[str] = None @@ -2114,6 +2220,7 @@ def save_checkpoint( categories, all_results, stages, + expansion_profile=profile, indexed_items=indexed, run_status=run_status, interrupted=interrupted, @@ -2130,6 +2237,7 @@ def save_checkpoint( print(f"\nQuery categories: {', '.join(f'{k}({len(v)})' for k, v in categories.items())}") print(f"Stages: {', '.join(stage_name for stage_name, _ in stages)}") + print(f"Expansion profile: {profile.name}") print(f"Total queries: {sum(len(v) for v in categories.values())}") print(f"Corpus documents: {len(CORPUS_DOCS)}") @@ -2186,6 +2294,7 @@ def save_checkpoint( results, latency = run_search( backend, storage, gt, collection, effective_mode, + expansion_profile=profile.name, ) eval_detail = evaluate_results_detailed(results, gt, CORPUS_DIR) memory_metrics = eval_detail["memory"] @@ -2397,6 +2506,15 @@ def main(): default=None, help="Cap how many queries to run from each selected category", ) + parser.add_argument( + "--expansion-profile", + choices=sorted(EXPANSION_PROFILES), + default="caption_only", + help=( + "Query expansion profile: off (pure vector/media), caption_only " + "(current default), heuristic, or qwen" + ), + ) args = parser.parse_args() if args.dry_run: @@ -2410,6 +2528,7 @@ def main(): ) print(f"Total queries: {sum(len(v) for v in categories.values())}") print(f"Corpus documents: {len(CORPUS_DOCS)}") + print(f"Expansion profile: {args.expansion_profile}") print(f"\nQueries by category:") for cat, queries in categories.items(): @@ -2467,6 +2586,7 @@ def main(): category_filters=args.category, stage_filters=args.stage_mode, max_queries_per_category=args.max_queries_per_category, + expansion_profile=args.expansion_profile, ) finally: diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 08d298c..3b00a5d 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -40,11 +40,28 @@ UAT_MCP_LIVE=1 .venv/bin/python -m pytest -q tests/uat/test_uat_comprehensive.py Then run the expanded benchmark: ```bash -.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --output benchmarks/results/cross_modal_ablation_results.json +.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --expansion-profile caption_only --output benchmarks/results/cross_modal_ablation_results.json ``` The benchmark now checkpoints to JSON as it runs. If the run is interrupted, the output file still contains partial results plus progress metadata. +For query-expansion release decisions, compare at least these profiles: + +```bash +.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --expansion-profile caption_only +.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --expansion-profile heuristic +.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --expansion-profile qwen +``` + +Profile meanings: + +- `caption_only`: shipped default baseline for media queries. Text queries do not expand; image/video queries still use caption or transcript BM25 probes. +- `heuristic`: opt-in expansion branches using the legacy heuristic rewrite fallback. +- `qwen`: opt-in expansion branches using the backend `generate_text()` path when available. +- `off`: pure no-expansion baseline, including no media caption probe, useful for measuring the value of caption/transcript query text itself. + +When you omit `--output`, the benchmark now keeps profile-specific filenames for non-default runs, for example `cross_modal_ablation_results_qwen.json`. + ## 4. Tag and publish 1. Commit the release changes. diff --git a/docs/mcp-tools.md b/docs/mcp-tools.md index 7d5815c..2b74bff 100644 --- a/docs/mcp-tools.md +++ b/docs/mcp-tools.md @@ -149,7 +149,7 @@ Example MCP client config (Claude Desktop): | profile | string | No | — | Profile namespace filter | | intent | string (`exact_lookup`\|`semantic`\|`broad`) | No | — | Intent steering for RRF weights | | rerank_top_k | integer | No | 20 | Max top RRF candidates to rerank (`0` disables reranking) | -| expand | boolean | No | false | Enable VL-aware query expansion | +| expand | boolean | No | false | Enable opt-in query expansion. Text queries use Qwen-backed variants when the backend supports `generate_text()`, otherwise they fall back to heuristic rewrites. Image/video queries expand the generated caption/transcript probe text on the same rules. | \* Exactly one of `query`, `image_path`, or `video_path` must be provided. @@ -203,6 +203,7 @@ Example MCP client config (Claude Desktop): **Notes:** - Reuses the same retrieval pipeline as `search`, so explanations reflect the actual ranking path. +- `expand=true` is still opt-in. It adds extra retrieval branches, so expect a latency/quality tradeoff rather than a free win. - `provenance.rrf.sources` maps each contributing RRF list to that result’s rank in the list. - `provenance.reranker.scoring_path` shows whether the reranker used text or VL scoring. - `media_compensation_applied` is `true` for image/video candidates that received RRF compensation because BM25 cannot surface them structurally. diff --git a/src/recallforge/__init__.py b/src/recallforge/__init__.py index 17e4d57..f64d443 100644 --- a/src/recallforge/__init__.py +++ b/src/recallforge/__init__.py @@ -44,6 +44,10 @@ def _has_torch() -> bool: "RECALLFORGE_BACKEND": "Backend selector: auto | torch | mlx.", "RECALLFORGE_MODE": "Search mode: embed | hybrid.", "RECALLFORGE_MLX_QUANTIZE": "MLX quantization mode: bf16 | 4bit.", + "RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY": "Concurrency ceiling for heavy MLX multimodal ops (default 1 for local safety).", + "RECALLFORGE_MLX_VIDEO_SAMPLE_FPS": "Sampling rate for MLX raw-video processing (lower is safer).", + "RECALLFORGE_MLX_VIDEO_MAX_FRAMES": "Frame cap for MLX raw-video processing (default tuned for local safety).", + "RECALLFORGE_MLX_VIDEO_FALLBACK_MAX_FRAMES": "Frame cap for ffmpeg frame-averaging fallback when native video embedding is unavailable.", "RECALLFORGE_STORAGE": "Storage backend selector (currently lancedb).", "RECALLFORGE_STORE_PATH": "Path to RecallForge data store.", "RECALLFORGE_TRACE": "Enable verbose MCP server trace logging (1=true).", diff --git a/src/recallforge/backends/mlx_backend.py b/src/recallforge/backends/mlx_backend.py index acfa43f..e5b11b8 100644 --- a/src/recallforge/backends/mlx_backend.py +++ b/src/recallforge/backends/mlx_backend.py @@ -14,8 +14,10 @@ import logging import importlib.util import threading +import tempfile import warnings import re +from contextlib import contextmanager from typing import List, Dict, Any, Optional import numpy as np @@ -75,6 +77,78 @@ class MLXEmbeddingError(RuntimeError): """Raised when the MLX embedding pipeline fails.""" +_HEAVY_OP_GATE_INIT_LOCK = threading.Lock() +_HEAVY_OP_GATE = None +_HEAVY_OP_GATE_LIMIT = None + + +class _HeavyOpGate: + """Serialize the heaviest MLX operations to avoid local memory pileups.""" + + def __init__(self, limit: int, lock_path: Optional[str] = None): + self.limit = max(1, int(limit)) + self.lock_path = lock_path or os.path.join( + tempfile.gettempdir(), "recallforge-mlx-heavy-op.lock" + ) + self._semaphore = threading.BoundedSemaphore(self.limit) + self._thread_state = threading.local() + + def _acquire_file_lock(self, op_name: str): + if self.limit != 1: + return None + try: + import fcntl + except ImportError: + return None + + handle = open(self.lock_path, "a+", encoding="utf-8") + logger.debug( + "mlx_heavy_op_wait_host_lock op=%s lock_path=%s", + op_name, + self.lock_path, + ) + fcntl.flock(handle.fileno(), fcntl.LOCK_EX) + return handle + + def _release_file_lock(self, handle) -> None: + if handle is None: + return + try: + import fcntl + fcntl.flock(handle.fileno(), fcntl.LOCK_UN) + except Exception: + pass + finally: + handle.close() + + @contextmanager + def hold(self, op_name: str): + depth = getattr(self._thread_state, "depth", 0) + acquired = False + file_handle = None + if depth == 0: + logger.debug("mlx_heavy_op_wait op=%s limit=%d", op_name, self.limit) + self._semaphore.acquire() + acquired = True + file_handle = self._acquire_file_lock(op_name) + logger.debug("mlx_heavy_op_acquired op=%s limit=%d", op_name, self.limit) + + self._thread_state.depth = depth + 1 + try: + yield + finally: + new_depth = getattr(self._thread_state, "depth", 1) - 1 + if new_depth <= 0: + if hasattr(self._thread_state, "depth"): + delattr(self._thread_state, "depth") + self._release_file_lock(file_handle) + if acquired: + self._semaphore.release() + logger.debug("mlx_heavy_op_release op=%s", op_name) + else: + self._thread_state.depth = new_depth + + class MLXBackend(ModelBackend): """ MLX-based model backend for Apple Silicon. @@ -101,10 +175,13 @@ class MLXBackend(ModelBackend): "Given a search query, retrieve relevant candidates that answer the query." ) - # Video sampling: 1 fps adapts to video length (30s video = 30 frames, - # 5min video = 300 frames). Max cap prevents OOM on very long videos. + # Video sampling: 1 fps adapts to video length (30s video = 30 frames). + # Keep the default cap conservative for local-agent safety; callers can + # raise it explicitly via env vars when they want heavier runs. _VIDEO_SAMPLE_FPS = 1.0 - _VIDEO_MAX_FRAMES = 128 + _VIDEO_MAX_FRAMES = 32 + _VIDEO_FALLBACK_MAX_FRAMES = 8 + _DEFAULT_HEAVY_OP_CONCURRENCY = 1 # Captioning descriptors removed — they produced captions too generic for BM25. # See REC-129 for dedicated captioning model support. @@ -170,6 +247,79 @@ def __init__( self.CAPTION_MODEL = os.environ.get( "RECALLFORGE_CAPTIONER_MODEL", self._DEFAULT_CAPTION_MODEL ) + self._VIDEO_SAMPLE_FPS = self._resolve_positive_float_env( + "RECALLFORGE_MLX_VIDEO_SAMPLE_FPS", + self._VIDEO_SAMPLE_FPS, + ) + self._VIDEO_MAX_FRAMES = self._resolve_positive_int_env( + "RECALLFORGE_MLX_VIDEO_MAX_FRAMES", + self._VIDEO_MAX_FRAMES, + ) + self._VIDEO_FALLBACK_MAX_FRAMES = min( + self._VIDEO_MAX_FRAMES, + self._resolve_positive_int_env( + "RECALLFORGE_MLX_VIDEO_FALLBACK_MAX_FRAMES", + self._VIDEO_FALLBACK_MAX_FRAMES, + ), + ) + + def _resolve_heavy_op_concurrency(self) -> int: + """Return the configured MLX heavy-op concurrency ceiling.""" + raw = os.environ.get( + "RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY", + str(self._DEFAULT_HEAVY_OP_CONCURRENCY), + ).strip() + try: + value = int(raw) + except ValueError: + logger.warning( + "Invalid RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY=%r; using %d", + raw, + self._DEFAULT_HEAVY_OP_CONCURRENCY, + ) + return self._DEFAULT_HEAVY_OP_CONCURRENCY + return max(1, value) + + def _resolve_positive_int_env(self, name: str, default: int) -> int: + """Read a positive integer env var with graceful fallback.""" + raw = os.environ.get(name) + if raw is None: + return default + try: + value = int(raw.strip()) + except ValueError: + logger.warning("Invalid %s=%r; using %d", name, raw, default) + return default + return value if value > 0 else default + + def _resolve_positive_float_env(self, name: str, default: float) -> float: + """Read a positive float env var with graceful fallback.""" + raw = os.environ.get(name) + if raw is None: + return default + try: + value = float(raw.strip()) + except ValueError: + logger.warning("Invalid %s=%r; using %.3f", name, raw, default) + return default + return value if value > 0 else default + + def _get_heavy_op_gate(self) -> _HeavyOpGate: + """Return the shared gate used to limit overlapping MLX heavy ops.""" + global _HEAVY_OP_GATE, _HEAVY_OP_GATE_LIMIT + + limit = self._resolve_heavy_op_concurrency() + with _HEAVY_OP_GATE_INIT_LOCK: + if _HEAVY_OP_GATE is None or _HEAVY_OP_GATE_LIMIT != limit: + _HEAVY_OP_GATE = _HeavyOpGate(limit) + _HEAVY_OP_GATE_LIMIT = limit + return _HEAVY_OP_GATE + + @contextmanager + def _hold_heavy_op(self, op_name: str): + """Serialize the heaviest MLX operations for local safety.""" + with self._get_heavy_op_gate().hold(op_name): + yield # ========================================================================= # Embedder @@ -286,6 +436,45 @@ def _format_text_prompt(self, text: str, system: str = None) -> str: text=text, ) + def _apply_chat_template( + self, + processor: Any, + messages: List[Dict[str, Any]], + *, + tokenize: bool = False, + add_generation_prompt: bool = True, + ) -> Any: + """Apply a chat template with fallback to tokenizer templates. + + Some current MLX/Qwen processor objects expose a tokenizer chat template + but raise when `processor.apply_chat_template(...)` is called directly. + Fall back to the tokenizer so the live MLX path remains compatible. + """ + apply_template = getattr(processor, "apply_chat_template", None) + if callable(apply_template): + try: + return apply_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + ) + except ValueError as exc: + if "does not have a chat template" not in str(exc): + raise + except AttributeError: + pass + + tokenizer = getattr(processor, "tokenizer", None) + tokenizer_apply = getattr(tokenizer, "apply_chat_template", None) + if callable(tokenizer_apply): + return tokenizer_apply( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + ) + + raise ValueError("Processor does not expose a usable chat template.") + def _resolve_max_text_tokens(self) -> int: """Resolve a safe tokenizer max length for truncation.""" tokenizer = getattr(self._embedder_processor, "tokenizer", None) @@ -531,116 +720,118 @@ def embed_images(self, image_paths: List[str]) -> np.ndarray: 5) `get_input_embeddings(...)` -> transformer forward with `inputs_embeds` 6) Last-token pool + L2 normalize + float32 numpy """ - image_paths = self._validate_image_paths(image_paths) - if not image_paths: - return np.empty((0, self._EMBED_DIM), dtype=np.float32) + with self._hold_heavy_op("embed_images"): + image_paths = self._validate_image_paths(image_paths) + if not image_paths: + return np.empty((0, self._EMBED_DIM), dtype=np.float32) - self._load_embedder() - num_layers = self._get_embedder_num_layers() + self._load_embedder() + num_layers = self._get_embedder_num_layers() - try: - from qwen_vl_utils import process_vision_info - except ImportError as exc: - raise MLXEmbeddingError( - "qwen-vl-utils vision dependencies are missing. " - "Install qwen-vl-utils and torchvision for image embeddings." - ) from exc - - messages_batch = [] - for path in image_paths: - messages_batch.append([{ - "role": "user", - "content": [ - {"type": "image", "image": path}, - {"type": "text", "text": "Describe this image."}, - ], - }]) - - try: - chat_texts = [ - self._embedder_processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, - ) - for messages in messages_batch - ] - except Exception as exc: - raise MLXEmbeddingError( - "Failed to build chat templates for image embedding batch." - ) from exc + try: + from qwen_vl_utils import process_vision_info + except ImportError as exc: + raise MLXEmbeddingError( + "qwen-vl-utils vision dependencies are missing. " + "Install qwen-vl-utils and torchvision for image embeddings." + ) from exc - try: - image_inputs, _ = process_vision_info(messages_batch) - except Exception as exc: - raise MLXEmbeddingError( - "Failed to process vision inputs for image embedding batch." - ) from exc + messages_batch = [] + for path in image_paths: + messages_batch.append([{ + "role": "user", + "content": [ + {"type": "image", "image": path}, + {"type": "text", "text": "Describe this image."}, + ], + }]) - if not image_inputs: - raise MLXEmbeddingError( - "Vision pre-processing produced no image inputs." - ) - if len(image_inputs) != len(image_paths): - raise MLXEmbeddingError( - f"Vision pre-processing count mismatch: expected {len(image_paths)}, got {len(image_inputs)}." - ) + try: + chat_texts = [ + self._apply_chat_template( + self._embedder_processor, + messages, tokenize=False, add_generation_prompt=True, + ) + for messages in messages_batch + ] + except Exception as exc: + raise MLXEmbeddingError( + "Failed to build chat templates for image embedding batch." + ) from exc - try: - # Image processor requires PyTorch tensors, convert to MLX after - inputs = self._embedder_processor( - text=chat_texts, images=image_inputs, - return_tensors="pt", padding=True, - ) - except Exception as exc: - raise MLXEmbeddingError( - "Failed to tokenize image batch with processor." - ) from exc + try: + image_inputs, _ = process_vision_info(messages_batch) + except Exception as exc: + raise MLXEmbeddingError( + "Failed to process vision inputs for image embedding batch." + ) from exc - for required_key in ("input_ids", "pixel_values", "image_grid_thw"): - if required_key not in inputs: + if not image_inputs: + raise MLXEmbeddingError( + "Vision pre-processing produced no image inputs." + ) + if len(image_inputs) != len(image_paths): raise MLXEmbeddingError( - f"Image processor output is missing '{required_key}'." + f"Vision pre-processing count mismatch: expected {len(image_paths)}, got {len(image_inputs)}." ) - input_ids = self._to_mx_array(inputs["input_ids"], "input_ids") - pixel_values = self._to_mx_array(inputs["pixel_values"], "pixel_values") - image_grid_thw = self._to_mx_array(inputs["image_grid_thw"], "image_grid_thw") + try: + # Image processor requires PyTorch tensors, convert to MLX after + inputs = self._embedder_processor( + text=chat_texts, images=image_inputs, + return_tensors="pt", padding=True, + ) + except Exception as exc: + raise MLXEmbeddingError( + "Failed to tokenize image batch with processor." + ) from exc - # Free PyTorch tensors and vision intermediates now that we have MLX arrays - del inputs, image_inputs, chat_texts, messages_batch - gc.collect() + for required_key in ("input_ids", "pixel_values", "image_grid_thw"): + if required_key not in inputs: + raise MLXEmbeddingError( + f"Image processor output is missing '{required_key}'." + ) - try: - cache = _make_cache(num_layers) - except Exception as exc: - raise MLXEmbeddingError( - "Failed to initialize KV cache for image embedding batch." - ) from exc + input_ids = self._to_mx_array(inputs["input_ids"], "input_ids") + pixel_values = self._to_mx_array(inputs["pixel_values"], "pixel_values") + image_grid_thw = self._to_mx_array(inputs["image_grid_thw"], "image_grid_thw") - try: - h = self._embed_hidden_with_media( - input_ids, - pixel_values, - cache, - image_grid_thw=image_grid_thw, - ) - except Exception as exc: - raise MLXEmbeddingError( - "MLX vision forward pass failed." - ) from exc + # Free PyTorch tensors and vision intermediates now that we have MLX arrays + del inputs, image_inputs, chat_texts, messages_batch + gc.collect() - try: - embeddings = self._pool_and_normalize(h) - except Exception as exc: - raise MLXEmbeddingError( - "Failed to pool and normalize image embeddings." - ) from exc + try: + cache = _make_cache(num_layers) + except Exception as exc: + raise MLXEmbeddingError( + "Failed to initialize KV cache for image embedding batch." + ) from exc - if embeddings.shape[0] != len(image_paths): - raise MLXEmbeddingError( - f"Image embedding batch size mismatch: expected {len(image_paths)}, got {embeddings.shape[0]}." - ) + try: + h = self._embed_hidden_with_media( + input_ids, + pixel_values, + cache, + image_grid_thw=image_grid_thw, + ) + except Exception as exc: + raise MLXEmbeddingError( + "MLX vision forward pass failed." + ) from exc - return embeddings + try: + embeddings = self._pool_and_normalize(h) + except Exception as exc: + raise MLXEmbeddingError( + "Failed to pool and normalize image embeddings." + ) from exc + + if embeddings.shape[0] != len(image_paths): + raise MLXEmbeddingError( + f"Image embedding batch size mismatch: expected {len(image_paths)}, got {embeddings.shape[0]}." + ) + + return embeddings def _video_content(self, path: str) -> dict: """Build a video content dict with adaptive frame sampling. @@ -696,7 +887,8 @@ def _format_caption_prompt(self, image_path: str) -> str: ]} ] try: - return self._captioner_processor.apply_chat_template( + return self._apply_chat_template( + self._captioner_processor, messages, tokenize=False, add_generation_prompt=True ) finally: @@ -710,34 +902,35 @@ def caption_image(self, image_path: str) -> str: for the duration of the ingest batch. Call _unload_captioner() after batch completion to reclaim ~0.9 GB. """ - try: - self._load_captioner() - from mlx_vlm import generate as vlm_generate - - prompt = self._format_caption_prompt(image_path) - output = vlm_generate( - self._captioner_model, - self._captioner_processor, - prompt=prompt, - image=[image_path], - max_tokens=self._CAPTION_MAX_TOKENS, - ) - text = output.text if hasattr(output, "text") else str(output) - caption = re.sub(r"\s+", " ", text).strip() - logger.debug( - "caption_image path=%s caption_len=%d caption=%s", - image_path, len(caption), caption[:100], - ) - return caption[:512] # Safety cap for BM25 field length - except Exception as exc: - logger.warning("caption_image failed for %s: %s", image_path, exc) - return "" - finally: + with self._hold_heavy_op("caption_image"): try: - del prompt, output, text, caption - except Exception: - pass - gc.collect() + self._load_captioner() + from mlx_vlm import generate as vlm_generate + + prompt = self._format_caption_prompt(image_path) + output = vlm_generate( + self._captioner_model, + self._captioner_processor, + prompt=prompt, + image=[image_path], + max_tokens=self._CAPTION_MAX_TOKENS, + ) + text = output.text if hasattr(output, "text") else str(output) + caption = re.sub(r"\s+", " ", text).strip() + logger.debug( + "caption_image path=%s caption_len=%d caption=%s", + image_path, len(caption), caption[:100], + ) + return caption[:512] # Safety cap for BM25 field length + except Exception as exc: + logger.warning("caption_image failed for %s: %s", image_path, exc) + return "" + finally: + try: + del prompt, output, text, caption + except Exception: + pass + gc.collect() def describe_image(self, image_path: str) -> str: """Backward-compatible alias for caption_image.""" @@ -769,33 +962,35 @@ def generate_text(self, prompt: str, max_tokens: int = 60) -> str: Used for query expansion and other lightweight generation tasks. Text-only — no image input. """ - try: - self._load_captioner() - from mlx_vlm import generate as vlm_generate - - messages = [{"role": "user", "content": [ - {"type": "text", "text": prompt}, - ]}] - formatted = self._captioner_processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - output = vlm_generate( - self._captioner_model, - self._captioner_processor, - prompt=formatted, - max_tokens=max_tokens, - ) - text = output.text if hasattr(output, "text") else str(output) - return text.strip() - except Exception as exc: - logger.warning("generate_text failed: %s", exc) - return "" - finally: + with self._hold_heavy_op("generate_text"): try: - del messages, formatted, output, text - except Exception: - pass - gc.collect() + self._load_captioner() + from mlx_vlm import generate as vlm_generate + + messages = [{"role": "user", "content": [ + {"type": "text", "text": prompt}, + ]}] + formatted = self._apply_chat_template( + self._captioner_processor, + messages, tokenize=False, add_generation_prompt=True + ) + output = vlm_generate( + self._captioner_model, + self._captioner_processor, + prompt=formatted, + max_tokens=max_tokens, + ) + text = output.text if hasattr(output, "text") else str(output) + return text.strip() + except Exception as exc: + logger.warning("generate_text failed: %s", exc) + return "" + finally: + try: + del messages, formatted, output, text + except Exception: + pass + gc.collect() def embed_videos(self, video_paths: List[str]) -> np.ndarray: """ @@ -804,13 +999,32 @@ def embed_videos(self, video_paths: List[str]) -> np.ndarray: Qwen3-VL's video processor currently expects per-video sampling kwargs, so we process each video independently and stack the resulting vectors. """ - video_paths = self._validate_video_paths(video_paths) - if not video_paths: - return np.empty((0, self._EMBED_DIM), dtype=np.float32) + with self._hold_heavy_op("embed_videos"): + video_paths = self._validate_video_paths(video_paths) + if not video_paths: + return np.empty((0, self._EMBED_DIM), dtype=np.float32) - self._load_embedder() - num_layers = self._get_embedder_num_layers() + self._load_embedder() + num_layers = self._get_embedder_num_layers() + + embeddings: List[np.ndarray] = [] + for path in video_paths: + try: + embedding = self._embed_video_native(path, num_layers) + except Exception as exc: + logger.warning( + "native_video_embedding failed for %s: %s; falling back to frame embeddings", + path, + exc, + ) + embedding = self._embed_video_via_frames(path) + + embeddings.append(embedding) + return np.stack(embeddings).astype(np.float32) + + def _embed_video_native(self, path: str, num_layers: int) -> np.ndarray: + """Embed a video via qwen-vl-utils native video preprocessing.""" try: from qwen_vl_utils import process_vision_info except ImportError as exc: @@ -819,104 +1033,131 @@ def embed_videos(self, video_paths: List[str]) -> np.ndarray: "Install qwen-vl-utils and torchvision for video embeddings." ) from exc - embeddings: List[np.ndarray] = [] - for path in video_paths: - messages = [{ - "role": "user", - "content": [ - self._video_content(path), - {"type": "text", "text": "Describe this video."}, - ], - }] + messages = [{ + "role": "user", + "content": [ + self._video_content(path), + {"type": "text", "text": "Describe this video."}, + ], + }] - try: - chat_text = self._embedder_processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, - ) - except Exception as exc: - raise MLXEmbeddingError( - f"Failed to build chat template for video '{path}'." - ) from exc + try: + chat_text = self._apply_chat_template( + self._embedder_processor, + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception as exc: + raise MLXEmbeddingError( + f"Failed to build chat template for video '{path}'." + ) from exc - try: - _, video_inputs, video_kwargs = process_vision_info( - [messages], - return_video_kwargs=True, - ) - except Exception as exc: - raise MLXEmbeddingError( - f"Failed to process video inputs for '{path}'." - ) from exc + try: + _, video_inputs, video_kwargs = process_vision_info( + [messages], + return_video_kwargs=True, + ) + except Exception as exc: + raise MLXEmbeddingError( + f"Failed to process video inputs for '{path}'." + ) from exc + + if not video_inputs: + raise MLXEmbeddingError( + f"Video pre-processing produced no video inputs for '{path}'." + ) + + normalized_video_kwargs = dict(video_kwargs or {}) + fps_value = normalized_video_kwargs.get("fps") + if isinstance(fps_value, list): + normalized_video_kwargs["fps"] = fps_value[0] if fps_value else None + + try: + inputs = self._embedder_processor( + text=[chat_text], + videos=video_inputs, + return_tensors="pt", + padding=True, + **normalized_video_kwargs, + ) + except Exception as exc: + raise MLXEmbeddingError( + f"Failed to tokenize video '{path}' with processor." + ) from exc - if not video_inputs: + for required_key in ("input_ids", "pixel_values_videos", "video_grid_thw"): + if required_key not in inputs: raise MLXEmbeddingError( - f"Video pre-processing produced no video inputs for '{path}'." + f"Video processor output is missing '{required_key}' for '{path}'." ) - normalized_video_kwargs = dict(video_kwargs or {}) - fps_value = normalized_video_kwargs.get("fps") - if isinstance(fps_value, list): - normalized_video_kwargs["fps"] = fps_value[0] if fps_value else None + input_ids = self._to_mx_array(inputs["input_ids"], "input_ids") + pixel_values = self._to_mx_array(inputs["pixel_values_videos"], "pixel_values_videos") + video_grid_thw = self._to_mx_array(inputs["video_grid_thw"], "video_grid_thw") - try: - inputs = self._embedder_processor( - text=[chat_text], - videos=video_inputs, - return_tensors="pt", - padding=True, - **normalized_video_kwargs, - ) - except Exception as exc: - raise MLXEmbeddingError( - f"Failed to tokenize video '{path}' with processor." - ) from exc + del inputs, video_inputs, messages, chat_text, normalized_video_kwargs, video_kwargs + gc.collect() - for required_key in ("input_ids", "pixel_values_videos", "video_grid_thw"): - if required_key not in inputs: - raise MLXEmbeddingError( - f"Video processor output is missing '{required_key}' for '{path}'." - ) + try: + cache = _make_cache(num_layers) + except Exception as exc: + raise MLXEmbeddingError( + f"Failed to initialize KV cache for video '{path}'." + ) from exc - input_ids = self._to_mx_array(inputs["input_ids"], "input_ids") - pixel_values = self._to_mx_array(inputs["pixel_values_videos"], "pixel_values_videos") - video_grid_thw = self._to_mx_array(inputs["video_grid_thw"], "video_grid_thw") + try: + h = self._embed_hidden_with_media( + input_ids, + pixel_values, + cache, + video_grid_thw=video_grid_thw, + ) + except Exception as exc: + raise MLXEmbeddingError( + f"MLX video forward pass failed for '{path}'." + ) from exc - # Free PyTorch tensors and intermediates now that we have MLX arrays - del inputs, video_inputs, messages, chat_text, normalized_video_kwargs, video_kwargs - gc.collect() + try: + embedding = self._pool_and_normalize(h) + except Exception as exc: + raise MLXEmbeddingError( + f"Failed to pool and normalize video embedding for '{path}'." + ) from exc + finally: + del input_ids, pixel_values, video_grid_thw, cache, h - try: - cache = _make_cache(num_layers) - except Exception as exc: - raise MLXEmbeddingError( - f"Failed to initialize KV cache for video '{path}'." - ) from exc + return embedding[0] - try: - h = self._embed_hidden_with_media( - input_ids, - pixel_values, - cache, - video_grid_thw=video_grid_thw, - ) - except Exception as exc: - raise MLXEmbeddingError( - f"MLX video forward pass failed for '{path}'." - ) from exc + def _embed_video_via_frames(self, path: str) -> np.ndarray: + """Fallback raw-video embedding by averaging ffmpeg-extracted frame vectors.""" + from ..video import extract_video_frames - try: - embedding = self._pool_and_normalize(h) - except Exception as exc: + with tempfile.TemporaryDirectory(prefix="recallforge_video_query_") as temp_dir: + frames, _ = extract_video_frames( + path, + temp_dir, + logical_path=os.path.basename(path), + frame_interval_seconds=5.0, + max_frames=self._VIDEO_FALLBACK_MAX_FRAMES, + ) + frame_paths = [frame.image_path for frame in frames] + if not frame_paths: raise MLXEmbeddingError( - f"Failed to pool and normalize video embedding for '{path}'." - ) from exc - - # Free MLX intermediates before next video - del input_ids, pixel_values, video_grid_thw, cache, h + f"Video fallback produced no frames for '{path}'. Ensure ffmpeg/ffprobe are installed." + ) - embeddings.append(embedding[0]) + frame_embeddings = self.embed_images(frame_paths) + if frame_embeddings.size == 0: + raise MLXEmbeddingError( + f"Video fallback frame embeddings were empty for '{path}'." + ) - return np.stack(embeddings).astype(np.float32) + pooled = frame_embeddings.mean(axis=0) + norm = float(np.linalg.norm(pooled)) + if norm > 0: + pooled = pooled / norm + return pooled.astype(np.float32) # ========================================================================= # Reranker @@ -1092,7 +1333,8 @@ def _render_reranker_prompt( """Render reranker chat messages to text and report whether vision tokens survived.""" try: return ( - self._reranker_processor.apply_chat_template( + self._apply_chat_template( + self._reranker_processor, messages, tokenize=False, add_generation_prompt=True, @@ -1509,55 +1751,59 @@ def rerank( Returns list of scores. Scoring path information is logged via logger.debug. """ - if not documents: - return [] - - if not self.needs_reranker(): - return [0.5] * len(documents) + with self._hold_heavy_op("rerank"): + if not documents: + return [] - try: - self._load_reranker() - num_layers = self._reranker_model.language_model.model.num_hidden_layers - except Exception as e: - logger.error(f"[MLXBackend] Failed to initialize reranker: {e}") - return [0.5] * len(documents) + if not self.needs_reranker(): + return [0.5] * len(documents) - instruction = self._RERANK_DEFAULT_INSTRUCTION - scores: List[float] = [] - for idx, doc in enumerate(documents): try: - text = doc.get("text", "") or doc.get("text_body", "") or "" - doc_image_path = doc.get("image_path") - doc_video_path = doc.get("video_path") - messages = self._build_reranker_messages( - query, - text, - instruction, - image_path=doc_image_path, video_path=doc_video_path, - query_image_path=query_image_path, - query_video_path=query_video_path, - ) - prompt, template_ok = self._render_reranker_prompt( - messages, - query, - text, - instruction, - ) - score, scoring_path, raw_score = self._score_reranker_prompt( - prompt, num_layers, - messages=messages if template_ok else None, - ) - scores.append(score) - # Log per-document reranker path tracing - logger.debug("reranker_doc idx=%d path=%s raw_score=%.4f final_score=%.4f content_type=%s", - idx, scoring_path, raw_score, score, doc.get("content_type", "unknown")) + self._load_reranker() + num_layers = self._reranker_model.language_model.model.num_hidden_layers except Exception as e: - logger.error(f"[MLXBackend] Rerank error at doc {idx}: {e}") - scores.append(0.5) - logger.debug("reranker_doc idx=%d path=error_fallback raw_score=0.0 final_score=0.5 content_type=%s", - idx, doc.get("content_type", "unknown")) + logger.error(f"[MLXBackend] Failed to initialize reranker: {e}") + return [0.5] * len(documents) - return scores + instruction = self._RERANK_DEFAULT_INSTRUCTION + scores: List[float] = [] + for idx, doc in enumerate(documents): + try: + text = doc.get("text", "") or doc.get("text_body", "") or "" + doc_image_path = doc.get("image_path") + doc_video_path = doc.get("video_path") + messages = self._build_reranker_messages( + query, + text, + instruction, + image_path=doc_image_path, video_path=doc_video_path, + query_image_path=query_image_path, + query_video_path=query_video_path, + ) + prompt, template_ok = self._render_reranker_prompt( + messages, + query, + text, + instruction, + ) + score, scoring_path, raw_score = self._score_reranker_prompt( + prompt, num_layers, + messages=messages if template_ok else None, + ) + scores.append(score) + logger.debug( + "reranker_doc idx=%d path=%s raw_score=%.4f final_score=%.4f content_type=%s", + idx, scoring_path, raw_score, score, doc.get("content_type", "unknown") + ) + except Exception as e: + logger.error(f"[MLXBackend] Rerank error at doc {idx}: {e}") + scores.append(0.5) + logger.debug( + "reranker_doc idx=%d path=error_fallback raw_score=0.0 final_score=0.5 content_type=%s", + idx, doc.get("content_type", "unknown") + ) + + return scores # ========================================================================= # Warm-up and Status @@ -1567,20 +1813,21 @@ def warm_up(self) -> None: """Preload models and run a dummy embed pass to prime MLX compilation.""" import time - logger.info(f"[MLXBackend] Warming up (mode={self._mode}, quant={self._quantization})...") - start = time.time() + with self._hold_heavy_op("warm_up"): + logger.info(f"[MLXBackend] Warming up (mode={self._mode}, quant={self._quantization})...") + start = time.time() - self._load_embedder() - self._warm_embed() - t1 = time.time() - logger.info(f"[MLXBackend] Embedder+compile: {t1 - start:.1f}s") + self._load_embedder() + self._warm_embed() + t1 = time.time() + logger.info(f"[MLXBackend] Embedder+compile: {t1 - start:.1f}s") - if self.needs_reranker(): - self._load_reranker() - t2 = time.time() - logger.info(f"[MLXBackend] Reranker: {t2 - t1:.1f}s") + if self.needs_reranker(): + self._load_reranker() + t2 = time.time() + logger.info(f"[MLXBackend] Reranker: {t2 - t1:.1f}s") - logger.info(f"[MLXBackend] Ready in {time.time() - start:.1f}s") + logger.info(f"[MLXBackend] Ready in {time.time() - start:.1f}s") def get_info(self) -> BackendInfo: """Return backend information.""" diff --git a/src/recallforge/search.py b/src/recallforge/search.py index bce3ad5..30d1739 100644 --- a/src/recallforge/search.py +++ b/src/recallforge/search.py @@ -420,6 +420,7 @@ def __init__( cache: Optional[EmbeddingCache] = None, intent: Optional[str] = None, expand: bool = False, + enable_media_query_probe: bool = True, ): """ Initialize hybrid searcher. @@ -443,6 +444,8 @@ def __init__( cache: Optional EmbeddingCache; created with default maxsize if None intent: Optional intent for query steering ("exact_lookup", "semantic", "broad") expand: Whether to enable VL-aware query expansion (default: False, opt-in) + enable_media_query_probe: Whether image/video queries should generate + caption/transcript BM25 probe text before fusion. Default True. """ self.backend = backend self.storage = storage @@ -490,6 +493,7 @@ def __init__( self.cache: EmbeddingCache = cache if cache is not None else EmbeddingCache() self.intent = intent self.expand = expand + self.enable_media_query_probe = enable_media_query_probe def _vector_results_to_hybrid(self, results: List[SearchResult]) -> List[HybridResult]: """Convert raw vector results into HybridResult objects.""" @@ -624,6 +628,9 @@ def _query_media_probe( video_path: Optional[str] = None, ) -> tuple[str, List[SearchResult]]: """Generate a text probe from query media and run BM25 when possible.""" + if not self.enable_media_query_probe: + return "", [] + query_text = "" if image_path: query_text = self._caption_image_query(image_path) @@ -670,9 +677,11 @@ def _add_text_expansion_branches( def search_image(self, image_path: str) -> List[HybridResult]: """Run image-query search through hybrid pipeline (RRF + optional rerank).""" # Image query always contributes vector candidates. - vector = self._embed_image_cached(image_path) - all_results: Dict[str, List[SearchResult]] = { - "original_vec": self.storage.search_vec( + all_results: Dict[str, List[SearchResult]] = {} + query_image_path_for_rerank: Optional[str] = image_path + try: + vector = self._embed_image_cached(image_path) + all_results["original_vec"] = self.storage.search_vec( vector.tolist() if hasattr(vector, 'tolist') else list(vector), limit=self.fts_probe_limit, collection=self.collection, @@ -682,16 +691,21 @@ def search_image(self, image_path: str) -> List[HybridResult]: project_id=self.project_id, profile=self.profile, ) - } + except Exception as exc: + logger.warning("image query embedding failed for %s: %s", image_path, exc) + query_image_path_for_rerank = None query_text, bm25_results = self._query_media_probe(image_path=image_path) if bm25_results: all_results["original_fts"] = bm25_results self._add_text_expansion_branches(all_results, query_text) + if not all_results: + return [] + candidates, rrf_audit_info = self._reciprocal_rank_fusion(all_results) rerank_scores, reranker_path = self._rerank_candidates( - candidates, query=query_text, query_image_path=image_path + candidates, query=query_text, query_image_path=query_image_path_for_rerank ) return self._blend_scores(candidates, rerank_scores, rrf_audit_info, reranker_path) @@ -707,9 +721,11 @@ def search_video(self, video_path: str) -> List[HybridResult]: f"Backend {type(self.backend).__name__} does not support raw video queries. " "Install a backend with video support (e.g. recallforge[mlx] or recallforge[torch])." ) - vector = embed_video(video_path) - all_results: Dict[str, List[SearchResult]] = { - "original_vec": self.storage.search_vec( + all_results: Dict[str, List[SearchResult]] = {} + query_video_path_for_rerank: Optional[str] = video_path + try: + vector = embed_video(video_path) + all_results["original_vec"] = self.storage.search_vec( vector.tolist() if hasattr(vector, 'tolist') else list(vector), limit=self.fts_probe_limit, collection=self.collection, @@ -719,16 +735,21 @@ def search_video(self, video_path: str) -> List[HybridResult]: project_id=self.project_id, profile=self.profile, ) - } + except Exception as exc: + logger.warning("video query embedding failed for %s: %s", video_path, exc) + query_video_path_for_rerank = None query_text, bm25_results = self._query_media_probe(video_path=video_path) if bm25_results: all_results["original_fts"] = bm25_results self._add_text_expansion_branches(all_results, query_text) + if not all_results: + return [] + candidates, rrf_audit_info = self._reciprocal_rank_fusion(all_results) rerank_scores, reranker_path = self._rerank_candidates( - candidates, query=query_text, query_video_path=video_path + candidates, query=query_text, query_video_path=query_video_path_for_rerank ) return self._blend_scores(candidates, rerank_scores, rrf_audit_info, reranker_path) diff --git a/tests/test_cross_modal_benchmark_defs.py b/tests/test_cross_modal_benchmark_defs.py index 969f9c9..a2a5e2d 100644 --- a/tests/test_cross_modal_benchmark_defs.py +++ b/tests/test_cross_modal_benchmark_defs.py @@ -248,6 +248,7 @@ def test_output_payload_tracks_partial_progress(self): categories, {"Vector-only": {"text_to_text": stage_result}}, [("Vector-only", "vector")], + expansion_profile=module._resolve_expansion_profile("caption_only"), indexed_items=74, run_status="partial", interrupted=True, @@ -258,6 +259,9 @@ def test_output_payload_tracks_partial_progress(self): ) self.assertEqual(payload["version"], "0.2.0") + self.assertEqual(payload["configuration"]["expansion_profile"], "caption_only") + self.assertFalse(payload["configuration"]["expand_enabled"]) + self.assertTrue(payload["configuration"]["media_query_probe_enabled"]) self.assertEqual(payload["run_status"], "partial") self.assertTrue(payload["interrupted"]) self.assertEqual(payload["progress"]["indexed_items"], 74) @@ -277,6 +281,50 @@ def test_output_payload_tracks_partial_progress(self): payload["stages"]["Vector-only"]["text_to_text"]["per_query_results"][0]["asset_level"]["hit_at_1"] ) + def test_resolve_expansion_profile_variants(self): + module = _load_cross_modal_ablation() + + qwen = module._resolve_expansion_profile("qwen") + off = module._resolve_expansion_profile("off") + + self.assertTrue(qwen.expand) + self.assertTrue(qwen.allow_generate_text) + self.assertFalse(off.expand) + self.assertFalse(off.enable_media_query_probe) + + with self.assertRaises(ValueError): + module._resolve_expansion_profile("bogus") + + def test_expansion_backend_proxy_can_disable_generate_text(self): + module = _load_cross_modal_ablation() + + class _Backend: + def __init__(self): + self.calls = [] + + def generate_text(self, prompt: str, max_tokens: int = 60) -> str: + self.calls.append((prompt, max_tokens)) + return "ok" + + backend = _Backend() + disabled = module._ExpansionBackendProxy(backend, allow_generate_text=False) + enabled = module._ExpansionBackendProxy(backend, allow_generate_text=True) + + with self.assertRaises(NotImplementedError): + disabled.generate_text("prompt") + + self.assertEqual(enabled.generate_text("prompt", max_tokens=12), "ok") + self.assertEqual(backend.calls, [("prompt", 12)]) + + def test_resolve_output_path_suffixes_non_default_profiles(self): + module = _load_cross_modal_ablation() + + default_path = module._resolve_output_path(None, "caption_only") + qwen_path = module._resolve_output_path(None, "qwen") + + self.assertTrue(default_path.endswith("cross_modal_ablation_results.json")) + self.assertTrue(qwen_path.endswith("cross_modal_ablation_results_qwen.json")) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_mlx_reranker_prompt.py b/tests/test_mlx_reranker_prompt.py index 7d516a4..91ef70d 100644 --- a/tests/test_mlx_reranker_prompt.py +++ b/tests/test_mlx_reranker_prompt.py @@ -6,8 +6,11 @@ import os import sys +import tempfile import types import unittest +from contextlib import contextmanager +from types import SimpleNamespace from unittest.mock import patch import numpy as np @@ -39,6 +42,45 @@ def __call__(self, **kwargs): } +class _TokenizerOnlyChatTemplate: + def __init__(self): + self.calls = [] + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + self.calls.append( + { + "messages": messages, + "tokenize": tokenize, + "add_generation_prompt": add_generation_prompt, + } + ) + return "TOKENIZER_TEMPLATE" + + +class _ProcessorWithoutChatTemplate: + def __init__(self): + self.tokenizer = _TokenizerOnlyChatTemplate() + + def apply_chat_template(self, *_args, **_kwargs): + raise ValueError("Cannot use apply_chat_template because this processor does not have a chat template.") + + +class _ProcessorWithChatTemplate: + def __init__(self): + self.calls = [] + self.tokenizer = _TokenizerOnlyChatTemplate() + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + self.calls.append( + { + "messages": messages, + "tokenize": tokenize, + "add_generation_prompt": add_generation_prompt, + } + ) + return "PROCESSOR_TEMPLATE" + + class TestMLXRerankerPromptPreparation(unittest.TestCase): def _make_backend(self): backend = object.__new__(mlx_backend.MLXBackend) @@ -122,6 +164,151 @@ def test_text_only_messages_skip_vision_preprocessing(self): {"text": "PROMPT", "return_tensors": "np"}, ) + def test_apply_chat_template_falls_back_to_tokenizer_template(self): + backend = object.__new__(mlx_backend.MLXBackend) + processor = _ProcessorWithoutChatTemplate() + + rendered = backend._apply_chat_template( + processor, + [{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + tokenize=False, + add_generation_prompt=True, + ) + + self.assertEqual(rendered, "TOKENIZER_TEMPLATE") + self.assertEqual(len(processor.tokenizer.calls), 1) + + def test_apply_chat_template_prefers_processor_when_available(self): + backend = object.__new__(mlx_backend.MLXBackend) + processor = _ProcessorWithChatTemplate() + + rendered = backend._apply_chat_template( + processor, + [{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + tokenize=False, + add_generation_prompt=True, + ) + + self.assertEqual(rendered, "PROCESSOR_TEMPLATE") + self.assertEqual(len(processor.calls), 1) + self.assertEqual(len(processor.tokenizer.calls), 0) + + def test_resolve_heavy_op_concurrency_uses_env_and_falls_back(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._DEFAULT_HEAVY_OP_CONCURRENCY = 1 + + with patch.dict(os.environ, {"RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY": "3"}): + self.assertEqual(backend._resolve_heavy_op_concurrency(), 3) + + with patch.dict(os.environ, {"RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY": "0"}): + self.assertEqual(backend._resolve_heavy_op_concurrency(), 1) + + def test_heavy_op_gate_is_reentrant_within_same_thread(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._DEFAULT_HEAVY_OP_CONCURRENCY = 1 + + mlx_backend._HEAVY_OP_GATE = None + mlx_backend._HEAVY_OP_GATE_LIMIT = None + + with patch.dict(os.environ, {"RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY": "1"}): + with backend._hold_heavy_op("outer"): + with backend._hold_heavy_op("inner"): + gate = backend._get_heavy_op_gate() + self.assertEqual(getattr(gate._thread_state, "depth", 0), 2) + + gate = backend._get_heavy_op_gate() + self.assertEqual(getattr(gate._thread_state, "depth", 0), 0) + + def test_embed_videos_falls_back_to_frame_embeddings_when_native_path_fails(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._validate_video_paths = lambda paths: paths + backend._load_embedder = lambda: None + backend._get_embedder_num_layers = lambda: 2 + + def _raise_native(_path, _num_layers): + raise mlx_backend.MLXEmbeddingError("native failed") + + backend._embed_video_native = _raise_native + backend._embed_video_via_frames = lambda _path: np.array([0.6, 0.8], dtype=np.float32) + + embeddings = backend.embed_videos(["clip.mp4"]) + + np.testing.assert_allclose( + embeddings, + np.array([[0.6, 0.8]], dtype=np.float32), + ) + + def test_embed_video_via_frames_averages_and_normalizes_frame_embeddings(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._VIDEO_MAX_FRAMES = 128 + backend.embed_images = lambda _paths: np.array( + [[1.0, 0.0], [0.0, 1.0]], + dtype=np.float32, + ) + + frames = [ + SimpleNamespace(image_path="frame1.png"), + SimpleNamespace(image_path="frame2.png"), + ] + + with patch("recallforge.video.extract_video_frames", return_value=(frames, None)): + embedding = backend._embed_video_via_frames("clip.mp4") + + np.testing.assert_allclose( + embedding, + np.array([0.70710677, 0.70710677], dtype=np.float32), + rtol=1e-5, + atol=1e-5, + ) + + def test_resolve_heavy_op_concurrency_defaults_to_one(self): + backend = object.__new__(mlx_backend.MLXBackend) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY", None) + self.assertEqual(backend._resolve_heavy_op_concurrency(), 1) + + def test_resolve_heavy_op_concurrency_invalid_value_falls_back_to_one(self): + backend = object.__new__(mlx_backend.MLXBackend) + + with patch.dict(os.environ, {"RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY": "nope"}): + self.assertEqual(backend._resolve_heavy_op_concurrency(), 1) + + def test_heavy_op_gate_is_reentrant_on_same_thread(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp: + lock_path = tmp.name + try: + gate = mlx_backend._HeavyOpGate(limit=1, lock_path=lock_path) + with gate.hold("outer"): + with gate.hold("inner"): + self.assertEqual(getattr(gate._thread_state, "depth", 0), 2) + self.assertFalse(hasattr(gate._thread_state, "depth")) + finally: + os.unlink(lock_path) + + def test_embed_videos_uses_heavy_op_guard(self): + backend = object.__new__(mlx_backend.MLXBackend) + calls = [] + + @contextmanager + def fake_hold(name): + calls.append(name) + yield + + backend._hold_heavy_op = fake_hold + backend._validate_video_paths = lambda paths: paths + backend._load_embedder = lambda: None + backend._get_embedder_num_layers = lambda: 2 + backend._embed_video_native = lambda _path, _num_layers: np.array([1.0, 0.0], dtype=np.float32) + + embeddings = backend.embed_videos(["clip.mp4"]) + + self.assertEqual(calls, ["embed_videos"]) + np.testing.assert_allclose( + embeddings, + np.array([[1.0, 0.0]], dtype=np.float32), + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_query_expansion.py b/tests/test_query_expansion.py index d7e318d..b6ac347 100644 --- a/tests/test_query_expansion.py +++ b/tests/test_query_expansion.py @@ -219,13 +219,16 @@ def test_searcher_accepts_expand_parameter(self): expand=True, ) assert searcher.expand is True + assert searcher.enable_media_query_probe is True searcher_false = HybridSearcher( backend=backend, storage=storage, expand=False, + enable_media_query_probe=False, ) assert searcher_false.expand is False + assert searcher_false.enable_media_query_probe is False def test_searcher_defaults_expand_to_false(self): """HybridSearcher should default expand to False.""" @@ -237,6 +240,7 @@ def test_searcher_defaults_expand_to_false(self): storage=storage, ) assert searcher.expand is False + assert searcher.enable_media_query_probe is True class TestQueryExpansionIntegration: diff --git a/tests/test_search_pipeline.py b/tests/test_search_pipeline.py index f0105a3..4017b25 100644 --- a/tests/test_search_pipeline.py +++ b/tests/test_search_pipeline.py @@ -658,6 +658,47 @@ def test_search_image_uses_caption_for_bm25_probe(self): self.assertEqual(len(results), 1) self.assertIn("original_fts", results[0].source) + def test_search_image_can_disable_media_query_probe(self): + backend = StubBackend(mode="hybrid") + backend.rerank = MagicMock(return_value=[0.88]) + storage = StubStorage( + fts_results=[_make_search_result("doc1.md", 0.95, "fts")], + vec_results=[_make_search_result("doc1.md", 0.9, "vec")], + ) + storage.search_fts = MagicMock(side_effect=storage.search_fts) + searcher = HybridSearcher( + backend=backend, + storage=storage, + limit=1, + enable_media_query_probe=False, + ) + + results = searcher.search_image("/tmp/query.png") + + self.assertEqual(len(results), 1) + storage.search_fts.assert_not_called() + self.assertEqual(results[0].source, "original_vec") + + def test_search_image_falls_back_to_caption_probe_when_query_embedding_fails(self): + backend = StubBackend(mode="hybrid") + backend.rerank = MagicMock(return_value=[0.88]) + backend.embed_image = MagicMock(side_effect=RuntimeError("image embed failed")) + storage = StubStorage( + fts_results=[_make_search_result("doc1.md", 0.95, "fts")], + vec_results=[], + ) + searcher = HybridSearcher(backend=backend, storage=storage, limit=1) + + with self.assertLogs("recallforge.search", level="WARNING") as captured: + results = searcher.search_image("/tmp/query.png") + + self.assertEqual(len(results), 1) + self.assertEqual(storage.last_fts_query, "caption for query.png") + self.assertIn("original_fts", results[0].source) + self.assertTrue( + any("image query embedding failed" in message for message in captured.output) + ) + def test_search_image_expands_caption_probe_when_enabled(self): backend = StubBackend( mode="hybrid", @@ -769,6 +810,28 @@ def test_search_video_expands_caption_probe_when_enabled(self): ) self.assertEqual(storage.search_vec.call_count, 3) + def test_search_video_falls_back_to_caption_probe_when_query_embedding_fails(self): + backend = StubBackend(mode="hybrid") + backend.embed_video = MagicMock(side_effect=RuntimeError("video embed failed")) + backend.rerank = MagicMock(return_value=[0.88]) + storage = StubStorage( + fts_results=[_make_search_result("clip.md", 0.95, "fts")], + vec_results=[], + ) + searcher = HybridSearcher(backend=backend, storage=storage, limit=1) + searcher._caption_video_query = MagicMock(return_value="forest timelapse mountains") + + with self.assertLogs("recallforge.search", level="WARNING") as captured: + results = searcher.search_video("/tmp/query.mp4") + + self.assertEqual(len(results), 1) + self.assertEqual(storage.last_fts_query, "forest timelapse mountains") + searcher._caption_video_query.assert_called_once_with("/tmp/query.mp4") + self.assertIn("original_fts", results[0].source) + self.assertTrue( + any("video query embedding failed" in message for message in captured.output) + ) + def test_text_query_with_media_candidates_skips_reranker_by_default(self): backend = StubBackend(mode="hybrid") backend.rerank = MagicMock(return_value=[0.99]) From 39ea5b3f2f2c02cc7d0ef9258e14bb80f9fc0517 Mon Sep 17 00:00:00 2001 From: MollyAI Date: Fri, 27 Mar 2026 22:40:57 -0400 Subject: [PATCH 2/7] Add MLX heavy-op guard regression tests --- tests/test_mlx_reranker_prompt.py | 64 ++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/test_mlx_reranker_prompt.py b/tests/test_mlx_reranker_prompt.py index 91ef70d..acc750f 100644 --- a/tests/test_mlx_reranker_prompt.py +++ b/tests/test_mlx_reranker_prompt.py @@ -240,7 +240,8 @@ def _raise_native(_path, _num_layers): def test_embed_video_via_frames_averages_and_normalizes_frame_embeddings(self): backend = object.__new__(mlx_backend.MLXBackend) - backend._VIDEO_MAX_FRAMES = 128 + backend._VIDEO_MAX_FRAMES = 32 + backend._VIDEO_FALLBACK_MAX_FRAMES = 2 backend.embed_images = lambda _paths: np.array( [[1.0, 0.0], [0.0, 1.0]], dtype=np.float32, @@ -261,6 +262,32 @@ def test_embed_video_via_frames_averages_and_normalizes_frame_embeddings(self): atol=1e-5, ) + def test_embed_video_via_frames_uses_configured_fallback_cap(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._VIDEO_FALLBACK_MAX_FRAMES = 3 + backend.embed_images = lambda _paths: np.array([[1.0, 0.0]], dtype=np.float32) + + def fake_extract_video_frames( + path, + temp_dir, + logical_path, + frame_interval_seconds, + max_frames, + ): + self.assertEqual(path, "clip.mp4") + self.assertEqual(logical_path, "clip.mp4") + self.assertEqual(frame_interval_seconds, 5.0) + self.assertEqual(max_frames, 3) + return ([SimpleNamespace(image_path="frame1.png")], None) + + with patch("recallforge.video.extract_video_frames", side_effect=fake_extract_video_frames): + embedding = backend._embed_video_via_frames("clip.mp4") + + np.testing.assert_allclose( + embedding, + np.array([1.0, 0.0], dtype=np.float32), + ) + def test_resolve_heavy_op_concurrency_defaults_to_one(self): backend = object.__new__(mlx_backend.MLXBackend) @@ -274,6 +301,41 @@ def test_resolve_heavy_op_concurrency_invalid_value_falls_back_to_one(self): with patch.dict(os.environ, {"RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY": "nope"}): self.assertEqual(backend._resolve_heavy_op_concurrency(), 1) + def test_resolve_positive_video_envs_fall_back_gracefully(self): + backend = object.__new__(mlx_backend.MLXBackend) + + with patch.dict( + os.environ, + { + "RECALLFORGE_MLX_VIDEO_MAX_FRAMES": "48", + "RECALLFORGE_MLX_VIDEO_SAMPLE_FPS": "0.5", + }, + ): + self.assertEqual( + backend._resolve_positive_int_env("RECALLFORGE_MLX_VIDEO_MAX_FRAMES", 32), + 48, + ) + self.assertEqual( + backend._resolve_positive_float_env("RECALLFORGE_MLX_VIDEO_SAMPLE_FPS", 1.0), + 0.5, + ) + + with patch.dict( + os.environ, + { + "RECALLFORGE_MLX_VIDEO_MAX_FRAMES": "0", + "RECALLFORGE_MLX_VIDEO_SAMPLE_FPS": "-2", + }, + ): + self.assertEqual( + backend._resolve_positive_int_env("RECALLFORGE_MLX_VIDEO_MAX_FRAMES", 32), + 32, + ) + self.assertEqual( + backend._resolve_positive_float_env("RECALLFORGE_MLX_VIDEO_SAMPLE_FPS", 1.0), + 1.0, + ) + def test_heavy_op_gate_is_reentrant_on_same_thread(self): with tempfile.NamedTemporaryFile(delete=False) as tmp: lock_path = tmp.name From 5225343d8bbb9c376072754b992233c05e3a056a Mon Sep 17 00:00:00 2001 From: MollyAI Date: Fri, 27 Mar 2026 22:43:29 -0400 Subject: [PATCH 3/7] Harden MLX video query fallback behavior --- src/recallforge/__init__.py | 1 + src/recallforge/search.py | 58 ++++++++++++++++++++++------------- tests/test_search_pipeline.py | 39 +++++++++++++++++++++++ 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/src/recallforge/__init__.py b/src/recallforge/__init__.py index f64d443..68bc289 100644 --- a/src/recallforge/__init__.py +++ b/src/recallforge/__init__.py @@ -56,6 +56,7 @@ def _has_torch() -> bool: "RECALLFORGE_MAX_CANDIDATES": "Hard cap for candidate pool before reranking.", "RECALLFORGE_RERANK_TOP_K": "Number of top RRF candidates sent to reranker.", "RECALLFORGE_ENABLE_MEDIA_RERANKING": "Enable multimodal reranking for image/video-involved searches (disabled by default).", + "RECALLFORGE_ENABLE_RAW_VIDEO_QUERY_EMBEDDING": "Enable raw video query embedding; MLX defaults to safer caption/transcript-first retrieval unless explicitly enabled.", "RECALLFORGE_MEDIA_QUERY_RERANK_TOP_K": "Rerank cap for query-side image/video searches.", "RECALLFORGE_MEDIA_RESULT_RERANK_TOP_K": "Rerank cap when text queries retrieve image/video candidates.", "RECALLFORGE_DISABLE_MLX": "Force-disable MLX backend detection (1=true).", diff --git a/src/recallforge/search.py b/src/recallforge/search.py index 30d1739..843aaec 100644 --- a/src/recallforge/search.py +++ b/src/recallforge/search.py @@ -40,6 +40,11 @@ def _env_flag(name: str, default: bool) -> bool: return raw.strip().lower() in {"1", "true", "yes", "on"} +def _is_mlx_backend(backend: ModelBackend) -> bool: + """Best-effort detection for the MLX backend without importing it here.""" + return type(backend).__name__ == "MLXBackend" + + def _log_stage_metrics( stage: str, results: List[Any], @@ -478,6 +483,10 @@ def __init__( "RECALLFORGE_ENABLE_MEDIA_RERANKING", False, ) + self.enable_raw_video_query_embedding = _env_flag( + "RECALLFORGE_ENABLE_RAW_VIDEO_QUERY_EMBEDDING", + not _is_mlx_backend(self.backend), + ) self.overfetch_factor = max(2, env_overfetch) self.max_candidates = max(self.limit, env_max_candidates) self.candidate_limit = min(self.max_candidates, self.limit * self.overfetch_factor) @@ -715,29 +724,36 @@ def search_video(self, video_path: str) -> List[HybridResult]: Raises: NotImplementedError: If the backend does not support native video embedding. """ - embed_video = getattr(self.backend, "embed_video", None) - if not callable(embed_video): - raise NotImplementedError( - f"Backend {type(self.backend).__name__} does not support raw video queries. " - "Install a backend with video support (e.g. recallforge[mlx] or recallforge[torch])." - ) all_results: Dict[str, List[SearchResult]] = {} - query_video_path_for_rerank: Optional[str] = video_path - try: - vector = embed_video(video_path) - all_results["original_vec"] = self.storage.search_vec( - vector.tolist() if hasattr(vector, 'tolist') else list(vector), - limit=self.fts_probe_limit, - collection=self.collection, - content_type=self.content_type, - user_id=self.user_id, - session_id=self.session_id, - project_id=self.project_id, - profile=self.profile, + query_video_path_for_rerank: Optional[str] = None + embed_video = getattr(self.backend, "embed_video", None) + if self.enable_raw_video_query_embedding: + if not callable(embed_video): + raise NotImplementedError( + f"Backend {type(self.backend).__name__} does not support raw video queries. " + "Install a backend with video support (e.g. recallforge[mlx] or recallforge[torch])." + ) + query_video_path_for_rerank = video_path + try: + vector = embed_video(video_path) + all_results["original_vec"] = self.storage.search_vec( + vector.tolist() if hasattr(vector, 'tolist') else list(vector), + limit=self.fts_probe_limit, + collection=self.collection, + content_type=self.content_type, + user_id=self.user_id, + session_id=self.session_id, + project_id=self.project_id, + profile=self.profile, + ) + except Exception as exc: + logger.warning("video query embedding failed for %s: %s", video_path, exc) + query_video_path_for_rerank = None + else: + logger.info( + "raw video query embedding disabled for backend=%s; using caption/transcript-first retrieval", + type(self.backend).__name__, ) - except Exception as exc: - logger.warning("video query embedding failed for %s: %s", video_path, exc) - query_video_path_for_rerank = None query_text, bm25_results = self._query_media_probe(video_path=video_path) if bm25_results: diff --git a/tests/test_search_pipeline.py b/tests/test_search_pipeline.py index 4017b25..ebe06b4 100644 --- a/tests/test_search_pipeline.py +++ b/tests/test_search_pipeline.py @@ -832,6 +832,45 @@ def test_search_video_falls_back_to_caption_probe_when_query_embedding_fails(sel any("video query embedding failed" in message for message in captured.output) ) + def test_search_video_skips_raw_embedding_for_mlx_backend_by_default(self): + MLXStubBackend = type("MLXBackend", (StubBackend,), {}) + backend = MLXStubBackend(mode="hybrid") + backend.embed_video = MagicMock(return_value=np.ones(2048, dtype=np.float32)) + storage = StubStorage( + fts_results=[_make_search_result("clip.md", 0.95, "fts")], + vec_results=[_make_search_result("clip.md", 0.9, "vec")], + ) + searcher = HybridSearcher(backend=backend, storage=storage, limit=1) + searcher._caption_video_query = MagicMock(return_value="forest timelapse mountains") + + with self.assertLogs("recallforge.search", level="INFO") as captured: + results = searcher.search_video("/tmp/query.mp4") + + backend.embed_video.assert_not_called() + self.assertEqual(storage.last_fts_query, "forest timelapse mountains") + self.assertEqual(len(results), 1) + self.assertTrue( + any("raw video query embedding disabled" in message for message in captured.output) + ) + + def test_search_video_can_opt_back_into_raw_embedding_for_mlx_backend(self): + MLXStubBackend = type("MLXBackend", (StubBackend,), {}) + backend = MLXStubBackend(mode="hybrid") + backend.embed_video = MagicMock(return_value=np.ones(2048, dtype=np.float32)) + storage = StubStorage( + fts_results=[_make_search_result("clip.md", 0.95, "fts")], + vec_results=[_make_search_result("clip.md", 0.9, "vec")], + ) + + with patch.dict(os.environ, {"RECALLFORGE_ENABLE_RAW_VIDEO_QUERY_EMBEDDING": "1"}): + searcher = HybridSearcher(backend=backend, storage=storage, limit=1) + searcher._caption_video_query = MagicMock(return_value="forest timelapse mountains") + results = searcher.search_video("/tmp/query.mp4") + + backend.embed_video.assert_called_once_with("/tmp/query.mp4") + self.assertEqual(storage.last_fts_query, "forest timelapse mountains") + self.assertEqual(len(results), 1) + def test_text_query_with_media_candidates_skips_reranker_by_default(self): backend = StubBackend(mode="hybrid") backend.rerank = MagicMock(return_value=[0.99]) From 7b0a8540419fa493b61f22a7874bf6fa94aa8588 Mon Sep 17 00:00:00 2001 From: MollyAI Date: Fri, 27 Mar 2026 22:44:54 -0400 Subject: [PATCH 4/7] Bound MLX processor media budgets --- src/recallforge/__init__.py | 2 ++ src/recallforge/backends/mlx_backend.py | 37 +++++++++++++++++++++++++ tests/test_mlx_reranker_prompt.py | 27 ++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/src/recallforge/__init__.py b/src/recallforge/__init__.py index 68bc289..6a6652b 100644 --- a/src/recallforge/__init__.py +++ b/src/recallforge/__init__.py @@ -48,6 +48,8 @@ def _has_torch() -> bool: "RECALLFORGE_MLX_VIDEO_SAMPLE_FPS": "Sampling rate for MLX raw-video processing (lower is safer).", "RECALLFORGE_MLX_VIDEO_MAX_FRAMES": "Frame cap for MLX raw-video processing (default tuned for local safety).", "RECALLFORGE_MLX_VIDEO_FALLBACK_MAX_FRAMES": "Frame cap for ffmpeg frame-averaging fallback when native video embedding is unavailable.", + "RECALLFORGE_MLX_MIN_PIXELS": "Lower bound for MLX processor visual resolution budgeting.", + "RECALLFORGE_MLX_MAX_PIXELS": "Upper bound for MLX processor visual resolution budgeting.", "RECALLFORGE_STORAGE": "Storage backend selector (currently lancedb).", "RECALLFORGE_STORE_PATH": "Path to RecallForge data store.", "RECALLFORGE_TRACE": "Enable verbose MCP server trace logging (1=true).", diff --git a/src/recallforge/backends/mlx_backend.py b/src/recallforge/backends/mlx_backend.py index e5b11b8..0ae204b 100644 --- a/src/recallforge/backends/mlx_backend.py +++ b/src/recallforge/backends/mlx_backend.py @@ -181,6 +181,8 @@ class MLXBackend(ModelBackend): _VIDEO_SAMPLE_FPS = 1.0 _VIDEO_MAX_FRAMES = 32 _VIDEO_FALLBACK_MAX_FRAMES = 8 + _VISION_MIN_PIXELS = 256 * 28 * 28 + _VISION_MAX_PIXELS = 1024 * 28 * 28 _DEFAULT_HEAVY_OP_CONCURRENCY = 1 # Captioning descriptors removed — they produced captions too generic for BM25. # See REC-129 for dedicated captioning model support. @@ -262,6 +264,17 @@ def __init__( self._VIDEO_FALLBACK_MAX_FRAMES, ), ) + self._VISION_MIN_PIXELS = self._resolve_positive_int_env( + "RECALLFORGE_MLX_MIN_PIXELS", + self._VISION_MIN_PIXELS, + ) + self._VISION_MAX_PIXELS = max( + self._VISION_MIN_PIXELS, + self._resolve_positive_int_env( + "RECALLFORGE_MLX_MAX_PIXELS", + self._VISION_MAX_PIXELS, + ), + ) def _resolve_heavy_op_concurrency(self) -> int: """Return the configured MLX heavy-op concurrency ceiling.""" @@ -321,6 +334,27 @@ def _hold_heavy_op(self, op_name: str): with self._get_heavy_op_gate().hold(op_name): yield + def _apply_processor_media_budgets(self, processor: Any) -> None: + """Apply conservative vision token budgets to a loaded processor.""" + targets = [processor] + image_processor = getattr(processor, "image_processor", None) + if image_processor is not None: + targets.append(image_processor) + + for target in targets: + for attr_name, value in ( + ("min_pixels", self._VISION_MIN_PIXELS), + ("max_pixels", self._VISION_MAX_PIXELS), + ): + if hasattr(target, attr_name): + setattr(target, attr_name, value) + + logger.debug( + "mlx_media_budgets min_pixels=%d max_pixels=%d", + self._VISION_MIN_PIXELS, + self._VISION_MAX_PIXELS, + ) + # ========================================================================= # Embedder # ========================================================================= @@ -355,6 +389,7 @@ def _load_embedder(self): self.EMBEDDER_MODEL, trust_remote_code=True, ) + self._apply_processor_media_budgets(self._embedder_processor) except Exception as exc: raise MLXEmbeddingError( f"Failed to load MLX embedder '{self.EMBEDDER_MODEL}'." @@ -865,6 +900,7 @@ def _load_captioner(self) -> None: self._captioner_model, self._captioner_processor = vlm_load( self.CAPTION_MODEL ) + self._apply_processor_media_budgets(self._captioner_processor) def _unload_captioner(self) -> None: """Free captioner memory when no longer needed.""" @@ -1716,6 +1752,7 @@ def _load_reranker(self): self.RERANKER_MODEL, trust_remote_code=True, ) + self._apply_processor_media_budgets(self._reranker_processor) finally: if hf_logging is not None and prev_hf_verbosity is not None: hf_logging.set_verbosity(prev_hf_verbosity) diff --git a/tests/test_mlx_reranker_prompt.py b/tests/test_mlx_reranker_prompt.py index acc750f..57baf66 100644 --- a/tests/test_mlx_reranker_prompt.py +++ b/tests/test_mlx_reranker_prompt.py @@ -81,6 +81,13 @@ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=Tr return "PROCESSOR_TEMPLATE" +class _ProcessorWithImageProcessor: + def __init__(self): + self.min_pixels = None + self.max_pixels = None + self.image_processor = SimpleNamespace(min_pixels=None, max_pixels=None) + + class TestMLXRerankerPromptPreparation(unittest.TestCase): def _make_backend(self): backend = object.__new__(mlx_backend.MLXBackend) @@ -371,6 +378,26 @@ def fake_hold(name): np.array([[1.0, 0.0]], dtype=np.float32), ) + def test_apply_processor_media_budgets_updates_processor_and_image_processor(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._VISION_MIN_PIXELS = 100 + backend._VISION_MAX_PIXELS = 200 + processor = _ProcessorWithImageProcessor() + + backend._apply_processor_media_budgets(processor) + + self.assertEqual(processor.min_pixels, 100) + self.assertEqual(processor.max_pixels, 200) + self.assertEqual(processor.image_processor.min_pixels, 100) + self.assertEqual(processor.image_processor.max_pixels, 200) + + def test_apply_processor_media_budgets_ignores_missing_attrs(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._VISION_MIN_PIXELS = 100 + backend._VISION_MAX_PIXELS = 200 + + backend._apply_processor_media_budgets(SimpleNamespace()) + if __name__ == "__main__": unittest.main() From 0115586771c487cd7b93a63690c8f8bd4e45dae7 Mon Sep 17 00:00:00 2001 From: MollyAI Date: Fri, 27 Mar 2026 22:46:57 -0400 Subject: [PATCH 5/7] Add crash-safe benchmark smoke profile --- benchmarks/cross_modal_ablation.py | 89 +++++++++++++++++++++++- tests/test_cross_modal_benchmark_defs.py | 39 +++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/benchmarks/cross_modal_ablation.py b/benchmarks/cross_modal_ablation.py index 2bc3217..026f085 100644 --- a/benchmarks/cross_modal_ablation.py +++ b/benchmarks/cross_modal_ablation.py @@ -1871,12 +1871,44 @@ def _resolve_output_path(output_path: Optional[str], expansion_profile: str) -> ) +def _apply_smoke_profile_defaults( + smoke_profile: str, + stage_filters: Optional[List[str]], + max_queries_per_category: Optional[int], + rss_limit_mb: Optional[int], +) -> tuple[Optional[List[str]], Optional[int], Optional[int]]: + """Resolve smoke-profile defaults without overriding explicit caller choices.""" + if smoke_profile != "safe": + return stage_filters, max_queries_per_category, rss_limit_mb + + resolved_stage_filters = stage_filters or ["rrf"] + resolved_max_queries = max_queries_per_category or 1 + resolved_rss_limit = rss_limit_mb or 6144 + return resolved_stage_filters, resolved_max_queries, resolved_rss_limit + + +def _current_rss_mb() -> Optional[float]: + """Return current peak RSS in MB when available.""" + try: + import resource + except ImportError: + return None + + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if sys.platform == "darwin": + return usage / (1024 * 1024) + return usage / 1024.0 + + def _build_output_payload( categories: Dict[str, List[GroundTruth]], all_results: Dict[str, Dict[str, StageResult]], stages: List[Tuple[str, str]], *, expansion_profile: ExpansionProfile, + smoke_profile: str, + rss_limit_mb: Optional[int], + peak_rss_mb: Optional[float], indexed_items: int, run_status: str, interrupted: bool, @@ -1895,6 +1927,8 @@ def _build_output_payload( "expand_enabled": expansion_profile.expand, "media_query_probe_enabled": expansion_profile.enable_media_query_probe, "generated_query_expansion_enabled": expansion_profile.allow_generate_text, + "smoke_profile": smoke_profile, + "rss_limit_mb": rss_limit_mb, }, "run_status": run_status, "interrupted": interrupted, @@ -1908,6 +1942,9 @@ def _build_output_payload( "current_stage": current_stage, "current_category": current_category, }, + "telemetry": { + "peak_rss_mb": None if peak_rss_mb is None else round(peak_rss_mb, 1), + }, "corpus": { "text_docs": len(list((CORPUS_DIR / "text").glob("*.md"))), "images": len(list((CORPUS_DIR / "images").glob("*.png"))), @@ -2192,9 +2229,18 @@ def run_benchmark( stage_filters: Optional[List[str]] = None, max_queries_per_category: Optional[int] = None, expansion_profile: str = "caption_only", + smoke_profile: str = "off", + rss_limit_mb: Optional[int] = None, ) -> Dict[str, Any]: """Run the full cross-modal ablation benchmark.""" + stage_filters, max_queries_per_category, rss_limit_mb = _apply_smoke_profile_defaults( + smoke_profile, + stage_filters, + max_queries_per_category, + rss_limit_mb, + ) + categories = _group_queries( category_filters=category_filters, max_queries_per_category=max_queries_per_category, @@ -2208,6 +2254,7 @@ def run_benchmark( completed_stages: List[str] = [] current_stage_name: Optional[str] = None current_category_name: Optional[str] = None + peak_rss_mb = _current_rss_mb() def save_checkpoint( *, @@ -2221,6 +2268,9 @@ def save_checkpoint( all_results, stages, expansion_profile=profile, + smoke_profile=smoke_profile, + rss_limit_mb=rss_limit_mb, + peak_rss_mb=peak_rss_mb, indexed_items=indexed, run_status=run_status, interrupted=interrupted, @@ -2238,6 +2288,9 @@ def save_checkpoint( print(f"\nQuery categories: {', '.join(f'{k}({len(v)})' for k, v in categories.items())}") print(f"Stages: {', '.join(stage_name for stage_name, _ in stages)}") print(f"Expansion profile: {profile.name}") + print(f"Smoke profile: {smoke_profile}") + if rss_limit_mb is not None: + print(f"RSS limit: {rss_limit_mb} MB") print(f"Total queries: {sum(len(v) for v in categories.values())}") print(f"Corpus documents: {len(CORPUS_DOCS)}") @@ -2247,6 +2300,11 @@ def save_checkpoint( # Ingest corpus print("\nIndexing corpus...") indexed = ingest_corpus(backend, storage, collection, CORPUS_DIR) + peak_rss_mb = max(peak_rss_mb or 0.0, _current_rss_mb() or 0.0) + if rss_limit_mb is not None and peak_rss_mb and peak_rss_mb > rss_limit_mb: + raise RuntimeError( + f"RSS limit exceeded during indexing: peak {peak_rss_mb:.1f} MB > limit {rss_limit_mb} MB" + ) print(f"Indexed {indexed} items.\n") save_checkpoint(run_status="partial") @@ -2315,6 +2373,11 @@ def save_checkpoint( sr.asset_precision_at_5_sum += asset_metrics.precision_at_5 sr.asset_precision_at_10_sum += asset_metrics.precision_at_10 sr.latencies_ms.append(latency) + peak_rss_mb = max(peak_rss_mb or 0.0, _current_rss_mb() or 0.0) + if rss_limit_mb is not None and peak_rss_mb and peak_rss_mb > rss_limit_mb: + raise RuntimeError( + f"RSS limit exceeded: peak {peak_rss_mb:.1f} MB > limit {rss_limit_mb} MB" + ) # Track per-difficulty hits if gt.difficulty == "easy": @@ -2515,20 +2578,42 @@ def main(): "(current default), heuristic, or qwen" ), ) + parser.add_argument( + "--smoke-profile", + choices=["off", "safe"], + default="off", + help="Optional bounded smoke profile for safer local validation.", + ) + parser.add_argument( + "--rss-limit-mb", + type=int, + default=None, + help="Abort the benchmark if peak RSS exceeds this limit in MB.", + ) args = parser.parse_args() if args.dry_run: + resolved_stage_mode, resolved_max_queries, resolved_rss_limit = _apply_smoke_profile_defaults( + args.smoke_profile, + args.stage_mode, + args.max_queries_per_category, + args.rss_limit_mb, + ) # Validate query structure print(f"\n{'=' * 60}") print("DRY RUN - Query Structure Validation") print(f"{'=' * 60}") categories = _group_queries( category_filters=args.category, - max_queries_per_category=args.max_queries_per_category, + max_queries_per_category=resolved_max_queries, ) print(f"Total queries: {sum(len(v) for v in categories.values())}") print(f"Corpus documents: {len(CORPUS_DOCS)}") print(f"Expansion profile: {args.expansion_profile}") + print(f"Smoke profile: {args.smoke_profile}") + print(f"Resolved stages: {resolved_stage_mode or 'all'}") + if resolved_rss_limit is not None: + print(f"Resolved RSS limit: {resolved_rss_limit} MB") print(f"\nQueries by category:") for cat, queries in categories.items(): @@ -2587,6 +2672,8 @@ def main(): stage_filters=args.stage_mode, max_queries_per_category=args.max_queries_per_category, expansion_profile=args.expansion_profile, + smoke_profile=args.smoke_profile, + rss_limit_mb=args.rss_limit_mb, ) finally: diff --git a/tests/test_cross_modal_benchmark_defs.py b/tests/test_cross_modal_benchmark_defs.py index a2a5e2d..a0e0b31 100644 --- a/tests/test_cross_modal_benchmark_defs.py +++ b/tests/test_cross_modal_benchmark_defs.py @@ -249,6 +249,9 @@ def test_output_payload_tracks_partial_progress(self): {"Vector-only": {"text_to_text": stage_result}}, [("Vector-only", "vector")], expansion_profile=module._resolve_expansion_profile("caption_only"), + smoke_profile="safe", + rss_limit_mb=4096, + peak_rss_mb=512.4, indexed_items=74, run_status="partial", interrupted=True, @@ -260,8 +263,11 @@ def test_output_payload_tracks_partial_progress(self): self.assertEqual(payload["version"], "0.2.0") self.assertEqual(payload["configuration"]["expansion_profile"], "caption_only") + self.assertEqual(payload["configuration"]["smoke_profile"], "safe") + self.assertEqual(payload["configuration"]["rss_limit_mb"], 4096) self.assertFalse(payload["configuration"]["expand_enabled"]) self.assertTrue(payload["configuration"]["media_query_probe_enabled"]) + self.assertEqual(payload["telemetry"]["peak_rss_mb"], 512.4) self.assertEqual(payload["run_status"], "partial") self.assertTrue(payload["interrupted"]) self.assertEqual(payload["progress"]["indexed_items"], 74) @@ -325,6 +331,39 @@ def test_resolve_output_path_suffixes_non_default_profiles(self): self.assertTrue(default_path.endswith("cross_modal_ablation_results.json")) self.assertTrue(qwen_path.endswith("cross_modal_ablation_results_qwen.json")) + def test_apply_smoke_profile_defaults(self): + module = _load_cross_modal_ablation() + + stages, max_queries, rss_limit = module._apply_smoke_profile_defaults( + "safe", + None, + None, + None, + ) + self.assertEqual(stages, ["rrf"]) + self.assertEqual(max_queries, 1) + self.assertEqual(rss_limit, 6144) + + stages, max_queries, rss_limit = module._apply_smoke_profile_defaults( + "safe", + ["hybrid"], + 2, + 2048, + ) + self.assertEqual(stages, ["hybrid"]) + self.assertEqual(max_queries, 2) + self.assertEqual(rss_limit, 2048) + + stages, max_queries, rss_limit = module._apply_smoke_profile_defaults( + "off", + None, + None, + None, + ) + self.assertIsNone(stages) + self.assertIsNone(max_queries) + self.assertIsNone(rss_limit) + if __name__ == "__main__": unittest.main() From e4ee34c455029754c706d15973c89fd569b61d17 Mon Sep 17 00:00:00 2001 From: MollyAI Date: Fri, 27 Mar 2026 22:48:12 -0400 Subject: [PATCH 6/7] Document MLX safety controls --- docs/ENV_VARS.md | 54 +++++++++++++++++++++++++++++++++++------------- docs/RELEASE.md | 31 +++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/docs/ENV_VARS.md b/docs/ENV_VARS.md index de508a3..a2998d6 100644 --- a/docs/ENV_VARS.md +++ b/docs/ENV_VARS.md @@ -4,50 +4,76 @@ This is the canonical reference for all `RECALLFORGE_*` environment variables us ## Runtime selection -- `RECALLFORGE_BACKEND` +- `RECALLFORGE_BACKEND` Backend selector: `auto` (default), `torch`, `mlx`. -- `RECALLFORGE_MODE` +- `RECALLFORGE_MODE` Search mode: `embed` or `hybrid`. -- `RECALLFORGE_MLX_QUANTIZE` +- `RECALLFORGE_MLX_QUANTIZE` MLX quantization mode: `bf16` or `4bit`. -- `RECALLFORGE_DISABLE_MLX` +- `RECALLFORGE_DISABLE_MLX` Disable MLX backend probing when set to `1`. -- `RECALLFORGE_STORAGE` +## MLX safety knobs + +- `RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY` + Concurrency ceiling for the heaviest MLX multimodal operations. Default is `1` for local safety. + +- `RECALLFORGE_MLX_VIDEO_SAMPLE_FPS` + Sampling rate for MLX raw-video processing. Lower values reduce memory pressure. + +- `RECALLFORGE_MLX_VIDEO_MAX_FRAMES` + Frame cap for MLX raw-video processing. The shipped default is intentionally conservative for local-agent use. + +- `RECALLFORGE_MLX_VIDEO_FALLBACK_MAX_FRAMES` + Frame cap for the ffmpeg-based frame-averaging fallback used when native video embedding is unavailable or downgraded. + +- `RECALLFORGE_MLX_MIN_PIXELS` + Lower bound for MLX processor visual resolution budgeting. + +- `RECALLFORGE_MLX_MAX_PIXELS` + Upper bound for MLX processor visual resolution budgeting. + +- `RECALLFORGE_STORAGE` Storage backend selector (currently `lancedb`). -- `RECALLFORGE_STORE_PATH` +- `RECALLFORGE_STORE_PATH` Path to the RecallForge data store. ## Search pipeline tuning -- `RECALLFORGE_OVERFETCH_FACTOR` +- `RECALLFORGE_OVERFETCH_FACTOR` Candidate overfetch multiplier before final trim. -- `RECALLFORGE_MAX_CANDIDATES` +- `RECALLFORGE_MAX_CANDIDATES` Hard cap for candidate pool size before reranking. -- `RECALLFORGE_RERANK_TOP_K` +- `RECALLFORGE_RERANK_TOP_K` Number of top RRF candidates to rerank. +- `RECALLFORGE_ENABLE_MEDIA_RERANKING` + Enable multimodal reranking for image/video-involved searches. Disabled by default. + +- `RECALLFORGE_ENABLE_RAW_VIDEO_QUERY_EMBEDDING` + Enable raw video query embedding. On MLX, RecallForge now defaults to safer caption/transcript-first retrieval unless you explicitly enable this. + ## Server behavior -- `RECALLFORGE_TRACE` +- `RECALLFORGE_TRACE` Enables trace logging for MCP tools when set to `1`. -- `RECALLFORGE_MCP_MAX_CONCURRENCY` +- `RECALLFORGE_MCP_MAX_CONCURRENCY` Maximum number of blocking MCP tool operations run concurrently. ## Storage/FTS internals -- `RECALLFORGE_BM25_FALLBACK_MAX_ROWS` +- `RECALLFORGE_BM25_FALLBACK_MAX_ROWS` Row limit used by BM25 fallback recovery paths. -- `RECALLFORGE_BULK_FLUSH_DOCS` +- `RECALLFORGE_BULK_FLUSH_DOCS` Batch flush threshold for document table writes. -- `RECALLFORGE_BULK_FLUSH_EMBEDDINGS` +- `RECALLFORGE_BULK_FLUSH_EMBEDDINGS` Batch flush threshold for embedding table writes. diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 3b00a5d..4171e48 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -45,6 +45,26 @@ Then run the expanded benchmark: The benchmark now checkpoints to JSON as it runs. If the run is interrupted, the output file still contains partial results plus progress metadata. +For safer local validation after the MLX hardening work, prefer the bounded smoke lane first: + +```bash +.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --smoke-profile safe --expansion-profile caption_only +``` + +That profile defaults to a smaller stage/query footprint and can enforce an RSS stop condition: + +```bash +.venv/bin/python benchmarks/cross_modal_ablation.py --backend mlx --smoke-profile safe --rss-limit-mb 6144 +``` + +The output JSON now records: + +- `configuration.smoke_profile` +- `configuration.rss_limit_mb` +- `telemetry.peak_rss_mb` + +Use the full benchmark only after the safe smoke completes cleanly on the target machine. + For query-expansion release decisions, compare at least these profiles: ```bash @@ -62,6 +82,17 @@ Profile meanings: When you omit `--output`, the benchmark now keeps profile-specific filenames for non-default runs, for example `cross_modal_ablation_results_qwen.json`. +### MLX safety notes + +- MLX heavy multimodal operations are now intentionally serialized by default via `RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY=1`. +- On MLX, raw video query embedding is no longer the default hot path. RecallForge prefers caption/transcript-first retrieval unless `RECALLFORGE_ENABLE_RAW_VIDEO_QUERY_EMBEDDING=1` is set. +- The raw-video path now has explicit frame and pixel budget knobs: + - `RECALLFORGE_MLX_VIDEO_SAMPLE_FPS` + - `RECALLFORGE_MLX_VIDEO_MAX_FRAMES` + - `RECALLFORGE_MLX_VIDEO_FALLBACK_MAX_FRAMES` + - `RECALLFORGE_MLX_MIN_PIXELS` + - `RECALLFORGE_MLX_MAX_PIXELS` + ## 4. Tag and publish 1. Commit the release changes. From 08dd1245ef27522f5d5186fa729bf5e4d3513c32 Mon Sep 17 00:00:00 2001 From: MollyAI Date: Sat, 28 Mar 2026 14:20:15 -0400 Subject: [PATCH 7/7] Stabilize MLX native video processing --- docs/ENV_VARS.md | 3 + docs/RELEASE.md | 2 + src/recallforge/__init__.py | 1 + src/recallforge/backends/mlx_backend.py | 143 +++++++++++++++++++++--- tests/test_mlx_reranker_prompt.py | 67 ++++++++++- 5 files changed, 197 insertions(+), 19 deletions(-) diff --git a/docs/ENV_VARS.md b/docs/ENV_VARS.md index a2998d6..ff323c2 100644 --- a/docs/ENV_VARS.md +++ b/docs/ENV_VARS.md @@ -36,6 +36,9 @@ This is the canonical reference for all `RECALLFORGE_*` environment variables us - `RECALLFORGE_MLX_MAX_PIXELS` Upper bound for MLX processor visual resolution budgeting. +- `RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING` + Enable qwen-vl-utils native video decoding on MLX. Disabled by default for local safety; if you opt in, prefer `FORCE_QWENVL_VIDEO_READER=torchcodec`. + - `RECALLFORGE_STORAGE` Storage backend selector (currently `lancedb`). diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 4171e48..e21e14a 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -86,6 +86,8 @@ When you omit `--output`, the benchmark now keeps profile-specific filenames for - MLX heavy multimodal operations are now intentionally serialized by default via `RECALLFORGE_MLX_HEAVY_OP_CONCURRENCY=1`. - On MLX, raw video query embedding is no longer the default hot path. RecallForge prefers caption/transcript-first retrieval unless `RECALLFORGE_ENABLE_RAW_VIDEO_QUERY_EMBEDDING=1` is set. +- On MLX, qwen-vl-utils native video decoding is now also opt-in. RecallForge defaults to frame/caption fallbacks unless `RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING=1` is set. +- If you do opt back into native MLX video decoding, prefer `FORCE_QWENVL_VIDEO_READER=torchcodec` per Qwen's upstream guidance. - The raw-video path now has explicit frame and pixel budget knobs: - `RECALLFORGE_MLX_VIDEO_SAMPLE_FPS` - `RECALLFORGE_MLX_VIDEO_MAX_FRAMES` diff --git a/src/recallforge/__init__.py b/src/recallforge/__init__.py index 6a6652b..5e4372a 100644 --- a/src/recallforge/__init__.py +++ b/src/recallforge/__init__.py @@ -50,6 +50,7 @@ def _has_torch() -> bool: "RECALLFORGE_MLX_VIDEO_FALLBACK_MAX_FRAMES": "Frame cap for ffmpeg frame-averaging fallback when native video embedding is unavailable.", "RECALLFORGE_MLX_MIN_PIXELS": "Lower bound for MLX processor visual resolution budgeting.", "RECALLFORGE_MLX_MAX_PIXELS": "Upper bound for MLX processor visual resolution budgeting.", + "RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING": "Enable qwen-vl-utils native video decoding for MLX; disabled by default for local safety.", "RECALLFORGE_STORAGE": "Storage backend selector (currently lancedb).", "RECALLFORGE_STORE_PATH": "Path to RecallForge data store.", "RECALLFORGE_TRACE": "Enable verbose MCP server trace logging (1=true).", diff --git a/src/recallforge/backends/mlx_backend.py b/src/recallforge/backends/mlx_backend.py index 0ae204b..c3bc797 100644 --- a/src/recallforge/backends/mlx_backend.py +++ b/src/recallforge/backends/mlx_backend.py @@ -184,6 +184,7 @@ class MLXBackend(ModelBackend): _VISION_MIN_PIXELS = 256 * 28 * 28 _VISION_MAX_PIXELS = 1024 * 28 * 28 _DEFAULT_HEAVY_OP_CONCURRENCY = 1 + _ENABLE_NATIVE_VIDEO_PROCESSING_DEFAULT = False # Captioning descriptors removed — they produced captions too generic for BM25. # See REC-129 for dedicated captioning model support. @@ -317,6 +318,54 @@ def _resolve_positive_float_env(self, name: str, default: float) -> float: return default return value if value > 0 else default + def _resolve_bool_env(self, name: str, default: bool) -> bool: + """Read a boolean env var with graceful fallback.""" + raw = os.environ.get(name) + if raw is None: + return default + normalized = raw.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + logger.warning("Invalid %s=%r; using %s", name, raw, default) + return default + + def _native_video_processing_enabled(self) -> bool: + """Return whether MLX should use qwen-vl-utils native video processing.""" + enabled = self._resolve_bool_env( + "RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING", + self._ENABLE_NATIVE_VIDEO_PROCESSING_DEFAULT, + ) + if not enabled: + if not getattr(self, "_warned_native_video_disabled", False): + logger.info( + "MLX native video processing is disabled by default for local safety; " + "using frame/caption fallbacks. Set " + "RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING=1 and prefer " + "FORCE_QWENVL_VIDEO_READER=torchcodec to opt in." + ) + self._warned_native_video_disabled = True + return False + + if getattr(self, "_warned_native_video_enabled", False): + return True + + reader = (os.environ.get("FORCE_QWENVL_VIDEO_READER") or "").strip().lower() + if reader == "torchcodec": + logger.info( + "MLX native video processing enabled with FORCE_QWENVL_VIDEO_READER=torchcodec." + ) + else: + configured = reader or "auto" + logger.warning( + "MLX native video processing is enabled with FORCE_QWENVL_VIDEO_READER=%s. " + "Qwen upstream recommends torchcodec for the most stable video loading.", + configured, + ) + self._warned_native_video_enabled = True + return True + def _get_heavy_op_gate(self) -> _HeavyOpGate: """Return the shared gate used to limit overlapping MLX heavy ops.""" global _HEAVY_OP_GATE, _HEAVY_OP_GATE_LIMIT @@ -355,6 +404,18 @@ def _apply_processor_media_budgets(self, processor: Any) -> None: self._VISION_MAX_PIXELS, ) + def _call_media_processor(self, processor: Any, **kwargs): + """Call a media processor without repeating qwen-vl-utils resizing.""" + proc_kwargs = dict(kwargs) + proc_kwargs.setdefault("do_resize", False) + try: + return processor(**proc_kwargs) + except TypeError as exc: + if "do_resize" not in str(exc): + raise + proc_kwargs.pop("do_resize", None) + return processor(**proc_kwargs) + # ========================================================================= # Embedder # ========================================================================= @@ -776,7 +837,7 @@ def embed_images(self, image_paths: List[str]) -> np.ndarray: messages_batch.append([{ "role": "user", "content": [ - {"type": "image", "image": path}, + self._image_content(path), {"type": "text", "text": "Describe this image."}, ], }]) @@ -812,7 +873,8 @@ def embed_images(self, image_paths: List[str]) -> np.ndarray: try: # Image processor requires PyTorch tensors, convert to MLX after - inputs = self._embedder_processor( + inputs = self._call_media_processor( + self._embedder_processor, text=chat_texts, images=image_inputs, return_tensors="pt", padding=True, ) @@ -868,15 +930,30 @@ def embed_images(self, image_paths: List[str]) -> np.ndarray: return embeddings + def _image_content(self, path: str) -> dict: + """Build an image content block with conservative qwen-vl-utils budgets.""" + return { + "type": "image", + "image": path, + "min_pixels": self._VISION_MIN_PIXELS, + "max_pixels": self._VISION_MAX_PIXELS, + } + def _video_content(self, path: str) -> dict: """Build a video content dict with adaptive frame sampling. Uses fps-based sampling (1 frame/sec) so longer videos get more frames. Caps at _VIDEO_MAX_FRAMES to bound memory on very long videos. - A 30s video → 30 frames. A 10min video → 128 frames (capped). + A 30s video → up to 30 frames. Longer videos clamp to the configured cap. """ - return {"type": "video", "video": path, "fps": self._VIDEO_SAMPLE_FPS, - "max_frames": self._VIDEO_MAX_FRAMES} + return { + "type": "video", + "video": path, + "fps": self._VIDEO_SAMPLE_FPS, + "max_frames": self._VIDEO_MAX_FRAMES, + "min_pixels": self._VISION_MIN_PIXELS, + "max_pixels": self._VISION_MAX_PIXELS, + } def embed_video(self, video_path: str) -> np.ndarray: """Embed a single video.""" @@ -1045,15 +1122,18 @@ def embed_videos(self, video_paths: List[str]) -> np.ndarray: embeddings: List[np.ndarray] = [] for path in video_paths: - try: - embedding = self._embed_video_native(path, num_layers) - except Exception as exc: - logger.warning( - "native_video_embedding failed for %s: %s; falling back to frame embeddings", - path, - exc, - ) + if not self._native_video_processing_enabled(): embedding = self._embed_video_via_frames(path) + else: + try: + embedding = self._embed_video_native(path, num_layers) + except Exception as exc: + logger.warning( + "native_video_embedding failed for %s: %s; falling back to frame embeddings", + path, + exc, + ) + embedding = self._embed_video_via_frames(path) embeddings.append(embedding) @@ -1110,7 +1190,8 @@ def _embed_video_native(self, path: str, num_layers: int) -> np.ndarray: normalized_video_kwargs["fps"] = fps_value[0] if fps_value else None try: - inputs = self._embedder_processor( + inputs = self._call_media_processor( + self._embedder_processor, text=[chat_text], videos=video_inputs, return_tensors="pt", @@ -1443,7 +1524,7 @@ def _build_reranker_messages( """Build the multimodal chat messages used for reranker prompting.""" query_content: list = [{"type": "text", "text": ":"}] if query_image_path: - query_content.append({"type": "image", "image": self._as_file_uri(query_image_path)}) + query_content.append(self._image_content(self._as_file_uri(query_image_path))) if query: query_content.append({"type": "text", "text": query}) elif query_video_path: @@ -1455,7 +1536,7 @@ def _build_reranker_messages( doc_content: list = [{"type": "text", "text": "\n:"}] if image_path: - doc_content.append({"type": "image", "image": self._as_file_uri(image_path)}) + doc_content.append(self._image_content(self._as_file_uri(image_path))) if document: doc_content.append({"type": "text", "text": document}) elif video_path: @@ -1480,6 +1561,23 @@ def _build_reranker_messages( }, ] + def _drop_video_blocks(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Remove video blocks so callers can degrade to text/image-only processing.""" + filtered_messages: List[Dict[str, Any]] = [] + for message in messages: + content = message.get("content") + if not isinstance(content, list): + filtered_messages.append(dict(message)) + continue + filtered_message = dict(message) + filtered_message["content"] = [ + block + for block in content + if not (isinstance(block, dict) and block.get("type") == "video") + ] + filtered_messages.append(filtered_message) + return filtered_messages + def _build_reranker_processor_inputs( self, prompt: str, @@ -1494,6 +1592,17 @@ def _build_reranker_processor_inputs( for message in user_messages: content_blocks.extend(message.get("content", [])) + has_video = any( + block.get("type") == "video" + for block in content_blocks + if isinstance(block, dict) + ) + if has_video and not self._native_video_processing_enabled(): + user_messages = self._drop_video_blocks(user_messages) + content_blocks = [] + for message in user_messages: + content_blocks.extend(message.get("content", [])) + has_vision = any( block.get("type") in {"image", "video"} for block in content_blocks @@ -1520,7 +1629,7 @@ def _build_reranker_processor_inputs( proc_kwargs["videos"] = video_inputs proc_kwargs.update(normalized_video_kwargs) - pt_inputs = self._reranker_processor(**proc_kwargs) + pt_inputs = self._call_media_processor(self._reranker_processor, **proc_kwargs) return { key: value.numpy() if hasattr(value, "numpy") else value for key, value in pt_inputs.items() diff --git a/tests/test_mlx_reranker_prompt.py b/tests/test_mlx_reranker_prompt.py index 57baf66..a4fd333 100644 --- a/tests/test_mlx_reranker_prompt.py +++ b/tests/test_mlx_reranker_prompt.py @@ -139,6 +139,8 @@ def fake_process_vision_info(conversations, return_video_kwargs=True): content = user_messages[0]["content"] image_blocks = [block for block in content if block.get("type") == "image"] self.assertEqual([block["image"] for block in image_blocks], expected_image_uris) + self.assertTrue(all("min_pixels" in block for block in image_blocks)) + self.assertTrue(all("max_pixels" in block for block in image_blocks)) return ["query_pixels", "doc_pixels"], None, {} fake_module = types.SimpleNamespace(process_vision_info=fake_process_vision_info) @@ -153,6 +155,7 @@ def fake_process_vision_info(conversations, return_video_kwargs=True): self.assertEqual(call["images"], ["query_pixels", "doc_pixels"]) self.assertEqual(call["return_tensors"], "pt") self.assertTrue(call["padding"]) + self.assertFalse(call["do_resize"]) def test_text_only_messages_skip_vision_preprocessing(self): backend = self._make_backend() @@ -171,6 +174,25 @@ def test_text_only_messages_skip_vision_preprocessing(self): {"text": "PROMPT", "return_tensors": "np"}, ) + def test_video_reranker_inputs_degrade_to_text_when_native_video_disabled(self): + backend = self._make_backend() + messages = backend._build_reranker_messages( + query="caption fallback", + document="doc body", + instruction="rank this", + video_path="docs/example.mp4", + ) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING", None) + inputs = backend._build_reranker_processor_inputs("PROMPT", messages) + + self.assertIn("input_ids", inputs) + self.assertEqual( + backend._reranker_processor.calls[0], + {"text": "PROMPT", "return_tensors": "np"}, + ) + def test_apply_chat_template_falls_back_to_tokenizer_template(self): backend = object.__new__(mlx_backend.MLXBackend) processor = _ProcessorWithoutChatTemplate() @@ -238,13 +260,33 @@ def _raise_native(_path, _num_layers): backend._embed_video_native = _raise_native backend._embed_video_via_frames = lambda _path: np.array([0.6, 0.8], dtype=np.float32) - embeddings = backend.embed_videos(["clip.mp4"]) + with patch.dict(os.environ, {"RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING": "1"}): + embeddings = backend.embed_videos(["clip.mp4"]) np.testing.assert_allclose( embeddings, np.array([[0.6, 0.8]], dtype=np.float32), ) + def test_embed_videos_use_frame_fallback_by_default_when_native_video_disabled(self): + backend = object.__new__(mlx_backend.MLXBackend) + backend._validate_video_paths = lambda paths: paths + backend._load_embedder = lambda: None + backend._get_embedder_num_layers = lambda: 2 + native_calls = [] + backend._embed_video_native = lambda *_args, **_kwargs: native_calls.append("native") + backend._embed_video_via_frames = lambda _path: np.array([0.6, 0.8], dtype=np.float32) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING", None) + embeddings = backend.embed_videos(["clip.mp4"]) + + self.assertEqual(native_calls, []) + np.testing.assert_allclose( + embeddings, + np.array([[0.6, 0.8]], dtype=np.float32), + ) + def test_embed_video_via_frames_averages_and_normalizes_frame_embeddings(self): backend = object.__new__(mlx_backend.MLXBackend) backend._VIDEO_MAX_FRAMES = 32 @@ -370,7 +412,8 @@ def fake_hold(name): backend._get_embedder_num_layers = lambda: 2 backend._embed_video_native = lambda _path, _num_layers: np.array([1.0, 0.0], dtype=np.float32) - embeddings = backend.embed_videos(["clip.mp4"]) + with patch.dict(os.environ, {"RECALLFORGE_ENABLE_MLX_NATIVE_VIDEO_PROCESSING": "1"}): + embeddings = backend.embed_videos(["clip.mp4"]) self.assertEqual(calls, ["embed_videos"]) np.testing.assert_allclose( @@ -398,6 +441,26 @@ def test_apply_processor_media_budgets_ignores_missing_attrs(self): backend._apply_processor_media_budgets(SimpleNamespace()) + def test_call_media_processor_disables_duplicate_resize(self): + backend = object.__new__(mlx_backend.MLXBackend) + + class _Processor: + def __init__(self): + self.calls = [] + + def __call__(self, **kwargs): + self.calls.append(kwargs) + return {"ok": True} + + processor = _Processor() + result = backend._call_media_processor(processor, text=["hello"], return_tensors="pt") + + self.assertEqual(result, {"ok": True}) + self.assertEqual( + processor.calls[0], + {"text": ["hello"], "return_tensors": "pt", "do_resize": False}, + ) + if __name__ == "__main__": unittest.main()