Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 207 additions & 59 deletions benchmarks/cross_modal_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import json
import math
import os
import re
import shutil
import sys
import tempfile
Expand Down Expand Up @@ -1365,6 +1366,13 @@ class StageResult:
mrr_sum: float = 0.0
precision_at_5_sum: float = 0.0
precision_at_10_sum: float = 0.0
asset_hits_at_1: int = 0
asset_hits_at_5: int = 0
asset_hits_at_10: int = 0
asset_ndcg_sum: float = 0.0
asset_mrr_sum: float = 0.0
asset_precision_at_5_sum: float = 0.0
asset_precision_at_10_sum: float = 0.0
latencies_ms: list = field(default_factory=list)
per_query_results: list = field(default_factory=list)

Expand Down Expand Up @@ -1407,6 +1415,34 @@ def precision_at_5(self) -> float:
def precision_at_10(self) -> float:
return self.precision_at_10_sum / max(self.total_queries, 1)

@property
def asset_recall_at_1(self) -> float:
    """Asset-level Recall@1 averaged over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_hits_at_1 / denominator

@property
def asset_recall_at_5(self) -> float:
    """Asset-level Recall@5 averaged over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_hits_at_5 / denominator

@property
def asset_recall_at_10(self) -> float:
    """Asset-level Recall@10 averaged over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_hits_at_10 / denominator

@property
def asset_ndcg_at_10(self) -> float:
    """Mean asset-level NDCG@10 over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_ndcg_sum / denominator

@property
def asset_mrr(self) -> float:
    """Mean asset-level reciprocal rank over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_mrr_sum / denominator

@property
def asset_precision_at_5(self) -> float:
    """Mean asset-level graded Precision@5 over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_precision_at_5_sum / denominator

@property
def asset_precision_at_10(self) -> float:
    """Mean asset-level graded Precision@10 over all evaluated queries."""
    denominator = max(1, self.total_queries)
    return self.asset_precision_at_10_sum / denominator

@property
def p50_ms(self) -> float:
if not self.latencies_ms:
Expand Down Expand Up @@ -1462,6 +1498,117 @@ def _ndcg(relevances: List[float], k: int = 10) -> float:
return dcg / ideal if ideal > 0 else 0.0


@dataclass
class EvaluationMetrics:
    """Per-query retrieval metrics computed from one ranked relevance vector.

    Produced by ``_score_relevances`` for a single query; aggregated across
    queries elsewhere (sums/counts on the stage result).
    """

    # True when a relevant item appears at rank 1 / within the top 5 / top 10.
    hit_at_1: bool
    hit_at_5: bool
    hit_at_10: bool
    # Normalized discounted cumulative gain over the top 10 results.
    ndcg: float
    # Reciprocal rank of the first relevant result; 0.0 when none in the top 10.
    rr: float
    # Graded precision normalized by the maximum relevance score (2.0).
    precision_at_5: float
    precision_at_10: float


def _normalize_benchmark_path(path: str, corpus_dir: Path) -> str:
"""Normalize benchmark filepaths to corpus-relative paths when possible."""
raw = str(path or "").strip()
if not raw:
return ""

if raw.startswith("recallforge://"):
without_scheme = raw[len("recallforge://"):]
_, _, raw = without_scheme.partition("/")
raw = raw or without_scheme

raw = re.sub(r"\s+", " ", raw).strip()
candidate = Path(raw)
if candidate.is_absolute():
try:
return candidate.resolve().relative_to(corpus_dir.resolve()).as_posix()
except Exception:
return candidate.resolve().as_posix()
return raw.lstrip("./")


def _memory_key_for_path(path: str, corpus_dir: Path) -> str:
    """Map a result or ground-truth path to its canonical parent memory path."""
    normalized = _normalize_benchmark_path(path, corpus_dir)
    if not normalized:
        return ""
    # A transcript sidecar belongs to the .mp4 video it was derived from.
    transcript_suffix = ".transcript.json"
    if normalized.endswith(transcript_suffix):
        return normalized[: -len(transcript_suffix)] + ".mp4"
    # Asset paths look like "<memory>::<asset>"; keep only the memory part.
    memory_part, _, _ = normalized.partition("::")
    return memory_part


def _score_relevances(relevances: List[float]) -> EvaluationMetrics:
    """Compute benchmark metrics for a ranked relevance vector.

    ``relevances`` is ordered by rank; entries are graded scores where 2.0 is
    the maximum relevance (used to normalize precision).
    """
    # Rank (1-based) of the first relevant item within the top 10, if any.
    first_hit_rank = None
    for rank, rel in enumerate(relevances[:10], start=1):
        if rel > 0:
            first_hit_rank = rank
            break

    max_rel = 2.0  # maximum graded relevance score

    def graded_precision(k: int) -> float:
        # When fewer than k results exist, normalize by what we actually have.
        if len(relevances) >= k:
            return sum(relevances[:k]) / (k * max_rel)
        if relevances:
            return sum(relevances) / (len(relevances) * max_rel)
        return 0.0

    return EvaluationMetrics(
        hit_at_1=first_hit_rank == 1,
        hit_at_5=first_hit_rank is not None and first_hit_rank <= 5,
        hit_at_10=first_hit_rank is not None,
        ndcg=_ndcg(relevances, 10),
        rr=1.0 / first_hit_rank if first_hit_rank else 0.0,
        precision_at_5=graded_precision(5),
        precision_at_10=graded_precision(10),
    )


def evaluate_results_detailed(
    results: List[Dict[str, Any]],
    gt: GroundTruth,
    corpus_dir: Path,
) -> Dict[str, EvaluationMetrics]:
    """Evaluate results at both parent-memory and raw asset granularity.

    Returns a dict with two ``EvaluationMetrics`` entries: "memory" (hits
    judged against canonical parent-memory paths) and "asset" (hits judged
    against the normalized raw asset paths).
    """
    # Best (max) relevance score per normalized asset / per memory key.
    asset_scores: Dict[str, int] = {}
    memory_scores: Dict[str, int] = {}

    for gt_path in gt.relevant_paths:
        score = gt.get_relevance_score(gt_path)
        asset_key = _normalize_benchmark_path(gt_path, corpus_dir)
        if asset_key:
            asset_scores[asset_key] = max(asset_scores.get(asset_key, 0), score)
        mem_key = _memory_key_for_path(gt_path, corpus_dir)
        if mem_key:
            memory_scores[mem_key] = max(memory_scores.get(mem_key, 0), score)

    # Build rank-ordered relevance vectors over the top 10 results.
    top_results = results[:10]
    asset_relevances: List[float] = [
        float(asset_scores.get(
            _normalize_benchmark_path(r.get("filepath", ""), corpus_dir), 0))
        for r in top_results
    ]
    memory_relevances: List[float] = [
        float(memory_scores.get(
            _memory_key_for_path(r.get("filepath", ""), corpus_dir), 0))
        for r in top_results
    ]

    return {
        "memory": _score_relevances(memory_relevances),
        "asset": _score_relevances(asset_relevances),
    }


def evaluate_results(
results: List[Dict[str, Any]],
gt: GroundTruth,
Expand All @@ -1471,42 +1618,16 @@ def evaluate_results(

Returns: (hit@1, hit@5, hit@10, ndcg@10, reciprocal_rank, precision@5, precision@10)
"""
# Normalize GT paths to absolute
gt_paths_abs = set()
for p in gt.relevant_paths:
abs_path = str((corpus_dir / p).resolve())
gt_paths_abs.add(abs_path)
gt_paths_abs.add(Path(p).stem)

def get_relevance_score(result: Dict) -> int:
fp = result.get("filepath", "")
# Check absolute path match
for gp in gt.relevant_paths:
if gp in fp or Path(gp).stem in fp:
return gt.get_relevance_score(gp)
return 0

# Build relevance vector with graded scores
relevances = []
first_hit_rank = None
for i, r in enumerate(results[:10]):
rel = get_relevance_score(r)
relevances.append(float(rel))
if rel > 0 and first_hit_rank is None:
first_hit_rank = i + 1

hit_1 = any(r > 0 for r in relevances[:1])
hit_5 = any(r > 0 for r in relevances[:5])
hit_10 = any(r > 0 for r in relevances[:10])
ndcg = _ndcg(relevances, 10)
rr = 1.0 / first_hit_rank if first_hit_rank else 0.0

# Precision@K with graded relevance (normalize by max relevance)
max_rel = 2.0 # Maximum relevance score
prec_5 = sum(relevances[:5]) / (5 * max_rel) if len(relevances) >= 5 else sum(relevances) / (len(relevances) * max_rel) if relevances else 0.0
prec_10 = sum(relevances[:10]) / (10 * max_rel) if len(relevances) >= 10 else sum(relevances) / (len(relevances) * max_rel) if relevances else 0.0

return hit_1, hit_5, hit_10, ndcg, rr, prec_5, prec_10
memory_metrics = evaluate_results_detailed(results, gt, corpus_dir)["memory"]
return (
memory_metrics.hit_at_1,
memory_metrics.hit_at_5,
memory_metrics.hit_at_10,
memory_metrics.ndcg,
memory_metrics.rr,
memory_metrics.precision_at_5,
memory_metrics.precision_at_10,
)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1736,6 +1857,15 @@ def _build_output_payload(
"mrr": None if sr.skipped else round(sr.mrr, 4),
"p50_ms": None if sr.skipped else round(sr.p50_ms, 1),
"p95_ms": None if sr.skipped else round(sr.p95_ms, 1),
"asset_level": {
"recall_at_1": None if sr.skipped else round(sr.asset_recall_at_1, 4),
"recall_at_5": None if sr.skipped else round(sr.asset_recall_at_5, 4),
"recall_at_10": None if sr.skipped else round(sr.asset_recall_at_10, 4),
"precision_at_5": None if sr.skipped else round(sr.asset_precision_at_5, 4),
"precision_at_10": None if sr.skipped else round(sr.asset_precision_at_10, 4),
"ndcg_at_10": None if sr.skipped else round(sr.asset_ndcg_at_10, 4),
"mrr": None if sr.skipped else round(sr.asset_mrr, 4),
},
"total_queries": sr.total_queries,
"by_difficulty": {
"easy": {
Expand Down Expand Up @@ -2057,27 +2187,36 @@ def save_checkpoint(
backend, storage, gt,
collection, effective_mode,
)
h1, h5, h10, ndcg, rr, prec_5, prec_10 = evaluate_results(results, gt, CORPUS_DIR)

sr.hits_at_1 += int(h1)
sr.hits_at_5 += int(h5)
sr.hits_at_10 += int(h10)
sr.ndcg_sum += ndcg
sr.mrr_sum += rr
sr.precision_at_5_sum += prec_5
sr.precision_at_10_sum += prec_10
eval_detail = evaluate_results_detailed(results, gt, CORPUS_DIR)
memory_metrics = eval_detail["memory"]
asset_metrics = eval_detail["asset"]

sr.hits_at_1 += int(memory_metrics.hit_at_1)
sr.hits_at_5 += int(memory_metrics.hit_at_5)
sr.hits_at_10 += int(memory_metrics.hit_at_10)
sr.ndcg_sum += memory_metrics.ndcg
sr.mrr_sum += memory_metrics.rr
sr.precision_at_5_sum += memory_metrics.precision_at_5
sr.precision_at_10_sum += memory_metrics.precision_at_10
sr.asset_hits_at_1 += int(asset_metrics.hit_at_1)
sr.asset_hits_at_5 += int(asset_metrics.hit_at_5)
sr.asset_hits_at_10 += int(asset_metrics.hit_at_10)
sr.asset_ndcg_sum += asset_metrics.ndcg
sr.asset_mrr_sum += asset_metrics.rr
sr.asset_precision_at_5_sum += asset_metrics.precision_at_5
sr.asset_precision_at_10_sum += asset_metrics.precision_at_10
sr.latencies_ms.append(latency)

# Track per-difficulty hits
if gt.difficulty == "easy":
sr.easy_hits_at_1 += int(h1)
sr.easy_hits_at_5 += int(h5)
sr.easy_hits_at_1 += int(memory_metrics.hit_at_1)
sr.easy_hits_at_5 += int(memory_metrics.hit_at_5)
elif gt.difficulty == "medium":
sr.medium_hits_at_1 += int(h1)
sr.medium_hits_at_5 += int(h5)
sr.medium_hits_at_1 += int(memory_metrics.hit_at_1)
sr.medium_hits_at_5 += int(memory_metrics.hit_at_5)
elif gt.difficulty == "hard":
sr.hard_hits_at_1 += int(h1)
sr.hard_hits_at_5 += int(h5)
sr.hard_hits_at_1 += int(memory_metrics.hit_at_1)
sr.hard_hits_at_5 += int(memory_metrics.hit_at_5)

# Store per-query result with audit trail for post-hoc analysis
sr.per_query_results.append({
Expand All @@ -2087,13 +2226,22 @@ def save_checkpoint(
"relevant_paths": gt.relevant_paths,
"difficulty": gt.difficulty,
"is_negative_control": gt.is_negative_control,
"hit_at_1": h1,
"hit_at_5": h5,
"hit_at_10": h10,
"ndcg": ndcg,
"mrr": rr,
"precision_at_5": prec_5,
"precision_at_10": prec_10,
"hit_at_1": memory_metrics.hit_at_1,
"hit_at_5": memory_metrics.hit_at_5,
"hit_at_10": memory_metrics.hit_at_10,
"ndcg": memory_metrics.ndcg,
"mrr": memory_metrics.rr,
"precision_at_5": memory_metrics.precision_at_5,
"precision_at_10": memory_metrics.precision_at_10,
"asset_level": {
"hit_at_1": asset_metrics.hit_at_1,
"hit_at_5": asset_metrics.hit_at_5,
"hit_at_10": asset_metrics.hit_at_10,
"ndcg": asset_metrics.ndcg,
"mrr": asset_metrics.rr,
"precision_at_5": asset_metrics.precision_at_5,
"precision_at_10": asset_metrics.precision_at_10,
},
"latency_ms": latency,
"results": results,
})
Expand All @@ -2113,7 +2261,7 @@ def save_checkpoint(
all_results[stage_name][cat_name] = sr
print(f" {stage_name} for {cat_name} ({len(queries)}q)... "
f"R@1={sr.recall_at_1:.1%} R@5={sr.recall_at_5:.1%} "
f"R@10={sr.recall_at_10:.1%} P@5={sr.precision_at_5:.3f} "
f"R@10={sr.recall_at_10:.1%} AssetR@1={sr.asset_recall_at_1:.1%} P@5={sr.precision_at_5:.3f} "
f"NDCG@10={sr.ndcg_at_10:.3f} MRR={sr.mrr:.3f}")
save_checkpoint(run_status="partial")

Expand Down
35 changes: 35 additions & 0 deletions tests/test_cross_modal_benchmark_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def test_video_frame_assets_count_as_hits_for_parent_video_ground_truth(self):
self.assertEqual(prec_5, 1.0)
self.assertEqual(prec_10, 1.0)

detailed = module.evaluate_results_detailed([result], gt, module.CORPUS_DIR)
self.assertTrue(detailed["memory"].hit_at_1)
self.assertFalse(detailed["asset"].hit_at_1)

def test_video_transcript_assets_count_as_hits_for_parent_video_ground_truth(self):
module = _load_cross_modal_ablation()

Expand All @@ -196,6 +200,10 @@ def test_video_transcript_assets_count_as_hits_for_parent_video_ground_truth(sel
self.assertEqual(prec_5, 1.0)
self.assertEqual(prec_10, 1.0)

detailed = module.evaluate_results_detailed([result], gt, module.CORPUS_DIR)
self.assertTrue(detailed["memory"].hit_at_1)
self.assertFalse(detailed["asset"].hit_at_1)

def test_output_payload_tracks_partial_progress(self):
module = _load_cross_modal_ablation()

Expand All @@ -215,6 +223,26 @@ def test_output_payload_tracks_partial_progress(self):
precision_at_5_sum=1.0,
precision_at_10_sum=1.0,
)
stage_result.asset_hits_at_1 = 0
stage_result.asset_hits_at_5 = 0
stage_result.asset_hits_at_10 = 0
stage_result.asset_ndcg_sum = 0.0
stage_result.asset_mrr_sum = 0.0
stage_result.asset_precision_at_5_sum = 0.0
stage_result.asset_precision_at_10_sum = 0.0
stage_result.per_query_results.append(
{
"query": module.TEXT_TO_TEXT[0].query,
"hit_at_1": True,
"hit_at_5": True,
"hit_at_10": True,
"asset_level": {
"hit_at_1": False,
"hit_at_5": False,
"hit_at_10": False,
},
}
)

payload = module._build_output_payload(
categories,
Expand All @@ -241,6 +269,13 @@ def test_output_payload_tracks_partial_progress(self):
payload["stages"]["Vector-only"]["text_to_text"]["recall_at_1"],
1.0,
)
self.assertEqual(
payload["stages"]["Vector-only"]["text_to_text"]["asset_level"]["recall_at_1"],
0.0,
)
self.assertFalse(
payload["stages"]["Vector-only"]["text_to_text"]["per_query_results"][0]["asset_level"]["hit_at_1"]
)


if __name__ == "__main__":
Expand Down
Loading