diff --git a/experiment/benchmarks/bbh.py b/experiment/benchmarks/bbh.py
index 582329b..5066ab0 100644
--- a/experiment/benchmarks/bbh.py
+++ b/experiment/benchmarks/bbh.py
@@ -69,20 +69,51 @@ def _format_choices(choices: dict[str, Any]) -> str:
     return "\n".join(lines)
 
 
-def format_prompt(question: str, choices: dict[str, Any]) -> str:
-    """Build the model-facing prompt for a BBH multiple-choice item."""
+def _has_choices(choices: Any) -> bool:
+    """True if the dataset row actually exposed a multiple-choice list."""
+    if not choices:
+        return False
+    if not isinstance(choices, dict):
+        return False
+    return bool(choices.get("text") or choices.get("label"))
+
+
+def format_prompt(question: str, choices: dict[str, Any] | None) -> str:
+    """Build the model-facing prompt for a BBH item.
+
+    Two shapes:
+    - MCQ (has `choices`) → "Answer: X" letter convention.
+    - Yes/No (no `choices`, e.g. causal_judgement, web_of_lies, navigate)
+      → "Answer: Yes" or "Answer: No".
+    """
+    if _has_choices(choices):
+        return (
+            f"{question.strip()}\n\n"
+            f"Choices:\n{_format_choices(choices)}\n\n"
+            "Reply with your reasoning, then end with exactly one line:\n"
+            "Answer: X\n"
+            "where X is the letter of the correct choice (A, B, C, ...)."
+        )
     return (
         f"{question.strip()}\n\n"
-        f"Choices:\n{_format_choices(choices)}\n\n"
         "Reply with your reasoning, then end with exactly one line:\n"
-        "Answer: X\n"
-        "where X is the letter of the correct choice (A, B, C, ...)."
+        "Answer: Yes\n"
+        "or\n"
+        "Answer: No"
     )
 
 
 def normalize_gold(target: str) -> str:
-    """BBH targets are usually a single letter; normalize to uppercase."""
-    letter = extract_choice(str(target))
+    """Normalize a BBH gold target.
+
+    Two shapes handled:
+    - MCQ letter (A, B, C, ...)             → uppercase letter
+    - Yes/No (causal_judgement and friends) → "YES" / "NO"
+    """
+    t = str(target).strip()
+    if t.lower() in ("yes", "no"):
+        return t.upper()
+    letter = extract_choice(t)
     if letter is None:
         raise ValueError(f"Could not parse gold target: {target!r}")
     return letter
@@ -120,12 +151,76 @@ def extract_choice(text: str) -> str | None:
     return None
 
 
+def extract_yes_no(text: str) -> str | None:
+    """Parse Yes / No from model output. Returns "YES", "NO", or None.
+
+    Mirrors the layered approach of `extract_choice`, strongest signal first:
+    explicit "Answer: Yes/No", then "the answer is yes/no", then any yes/no
+    in the last non-empty line, then a bare yes/no line among the four
+    non-empty lines before it (a verdict followed by a sign-off).
+    """
+    if not text or not text.strip():
+        return None
+
+    patterns = [
+        r"(?im)^\s*answer\s*:\s*(yes|no)\s*\.?\s*$",
+        r"(?im)^\s*answer\s*:\s*(yes|no)\b",
+        r"(?i)\b(?:the\s+)?answer\s+is\s+(yes|no)\b",
+    ]
+    for pat in patterns:
+        matches = re.findall(pat, text)
+        if matches:
+            return matches[-1].upper()
+
+    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
+    if not lines:
+        return None
+
+    # Fallback 1: first yes/no word in the last non-empty line.
+    m = re.search(r"\b(yes|no)\b", lines[-1], re.IGNORECASE)
+    if m:
+        return m.group(1).upper()
+
+    # Fallback 2: nearest earlier line (within the last 5) that is nothing
+    # but yes/no. Each line is tested on its own: a `$`-anchored search over
+    # the joined tail can only match inside the final line, which fallback 1
+    # has already rejected. Bare lines only, so a stray "no" in the
+    # reasoning of an MCQ reply is not mistaken for a verdict.
+    for ln in reversed(lines[-5:-1]):
+        m = re.fullmatch(r"\W*(yes|no)\W*", ln, re.IGNORECASE)
+        if m:
+            return m.group(1).upper()
+    return None
+
+
+def extract_answer(text: str) -> str | None:
+    """Unified extractor: try Yes/No first, then MCQ letter.
+
+    Yes/No has to go first because extract_choice's "Answer: X" pattern
+    (lacking a $ anchor) would match the "Y" in "Answer: Yes" before
+    extract_yes_no got a look. For an actual letter answer like "Answer: B",
+    extract_yes_no returns None and the call falls through to extract_choice.
+
+    Use this when you don't know the task shape (e.g. cross-task lexical
+    agreement). For scoring, prefer the gold-aware path in score_bbh.
+    """
+    return extract_yes_no(text) or extract_choice(text)
+
+
 def score_bbh(model_output: str, task: dict) -> tuple[bool, str]:
-    """Grade a BBH response against task['gold']."""
-    pred = extract_choice(model_output)
+    """Grade a BBH response against task['gold'].
+
+    Routes extraction by gold shape: letter golds get extract_choice,
+    YES/NO golds get extract_yes_no. Avoids the false-positive where a
+    Yes/No question's reasoning happens to contain a stray "(A)" string.
+    """
+    gold = task["gold"]
+    if gold in ("YES", "NO"):
+        pred = extract_yes_no(model_output)
+    else:
+        pred = extract_choice(model_output)
     if pred is None:
         return False, "unparseable"
-    gold = task["gold"]
     if pred == gold:
         return True, "passed"
     return False, f"expected {gold} got {pred}"
@@ -133,15 +228,16 @@ def score_bbh(model_output: str, task: dict) -> tuple[bool, str]:
 
 def _row_to_task(subtask: str, index: int, row: dict) -> dict:
     gold = normalize_gold(row["target"])
+    choices = row.get("choices")  # may be absent for Yes/No subtasks
     return {
         "task_id": f"bbh/{subtask}/{index}",
-        "prompt": format_prompt(row["question"], row["choices"]),
+        "prompt": format_prompt(row["question"], choices),
         "gold": gold,
         "benchmark": "bbh",
         "subtask": subtask,
         # Echo arms use task["prompt"]; keep raw fields for debugging.
         "question": row["question"],
-        "choices": row["choices"],
+        "choices": choices,
     }
 
 
diff --git a/experiment/benchmarks/bbh_arms.py b/experiment/benchmarks/bbh_arms.py
index 23eb40c..37ab3cc 100644
--- a/experiment/benchmarks/bbh_arms.py
+++ b/experiment/benchmarks/bbh_arms.py
@@ -10,7 +10,7 @@
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableLambda, RunnableParallel
 
-from benchmarks.bbh import extract_choice, score_bbh
+from benchmarks.bbh import extract_answer, score_bbh
 from chat_claude_code import ChatClaudeCode
 from run_pilot import SMALL_JUDGE_BASE_URL, SMALL_JUDGE_MODEL, _HAS_OLLAMA
 
@@ -55,8 +55,13 @@ def _normalize_answer_text(text: str) -> str:
 
 
 def lexical_agree(a: str, b: str) -> bool:
-    """Same final choice letter, or identical normalized text."""
-    ca, cb = extract_choice(a), extract_choice(b)
+    """Same final answer (letter OR Yes/No), or identical normalized text.
+
+    Uses the unified extract_answer so Yes/No-shaped subtasks
+    (causal_judgement, web_of_lies, navigate, sports_understanding) get the
+    same letter-equality fast path as MCQ subtasks.
+    """
+    ca, cb = extract_answer(a), extract_answer(b)
     if ca is not None and cb is not None:
         return ca == cb
     return _normalize_answer_text(a) == _normalize_answer_text(b)
diff --git a/experiment/tests/test_bbh_scoring.py b/experiment/tests/test_bbh_scoring.py
index fb524af..fd61086 100644
--- a/experiment/tests/test_bbh_scoring.py
+++ b/experiment/tests/test_bbh_scoring.py
@@ -8,7 +8,14 @@
 
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
-from benchmarks.bbh import extract_choice, format_prompt, normalize_gold, score_bbh
+from benchmarks.bbh import (
+    extract_answer,
+    extract_choice,
+    extract_yes_no,
+    format_prompt,
+    normalize_gold,
+    score_bbh,
+)
 
 
 class TestExtractChoice(unittest.TestCase):
@@ -69,6 +76,67 @@ def test_includes_choices(self) -> None:
         self.assertIn("Ann", p)
         self.assertIn("Answer: X", p)
 
+    def test_yes_no_prompt_when_no_choices(self) -> None:
+        p = format_prompt("Did X cause Y?", None)
+        self.assertIn("Did X cause Y?", p)
+        self.assertIn("Answer: Yes", p)
+        self.assertIn("Answer: No", p)
+        self.assertNotIn("Answer: X", p)
+
+    def test_yes_no_prompt_when_empty_choices(self) -> None:
+        p = format_prompt("Did X cause Y?", {})
+        self.assertIn("Answer: Yes", p)
+
+
+class TestYesNoSupport(unittest.TestCase):
+    def test_extract_yes_no_explicit(self) -> None:
+        self.assertEqual(extract_yes_no("Reasoning...\nAnswer: Yes"), "YES")
+        self.assertEqual(extract_yes_no("Reasoning...\nAnswer: No"), "NO")
+
+    def test_extract_yes_no_phrase(self) -> None:
+        self.assertEqual(extract_yes_no("So the answer is yes"), "YES")
+        self.assertEqual(extract_yes_no("Therefore the answer is no"), "NO")
+
+    def test_extract_yes_no_last_line(self) -> None:
+        self.assertEqual(extract_yes_no("Long reasoning ...\n...\nYes."), "YES")
+
+    def test_extract_yes_no_verdict_before_sign_off(self) -> None:
+        # The verdict line may be followed by a closing pleasantry.
+        self.assertEqual(extract_yes_no("Yes.\nI hope that helps!"), "YES")
+        self.assertEqual(extract_yes_no("**No**\nAny questions?"), "NO")
+
+    def test_extract_yes_no_unparseable(self) -> None:
+        self.assertIsNone(extract_yes_no("Maybe, but possibly"))
+        self.assertIsNone(extract_yes_no(""))
+
+    def test_normalize_gold_yes_no(self) -> None:
+        self.assertEqual(normalize_gold("Yes"), "YES")
+        self.assertEqual(normalize_gold("No"), "NO")
+        self.assertEqual(normalize_gold("yes"), "YES")
+
+    def test_normalize_gold_letter_still_works(self) -> None:
+        self.assertEqual(normalize_gold("A"), "A")
+        self.assertEqual(normalize_gold("(C)"), "C")
+
+    def test_score_yes_no_task(self) -> None:
+        task = {"task_id": "t", "prompt": "p", "gold": "YES"}
+        ok, _ = score_bbh("Answer: Yes", task)
+        self.assertTrue(ok)
+        ok, detail = score_bbh("Answer: No", task)
+        self.assertFalse(ok)
+        self.assertIn("expected YES got NO", detail)
+
+    def test_score_yes_no_ignores_stray_mcq_letters(self) -> None:
+        # A Yes/No question whose reasoning mentions "(A)" must not be
+        # mis-scored as an MCQ answer.
+        task = {"task_id": "t", "prompt": "p", "gold": "YES"}
+        ok, _ = score_bbh("Reasoning mentions (A) in passing.\nAnswer: Yes", task)
+        self.assertTrue(ok)
+
+    def test_extract_answer_unified(self) -> None:
+        self.assertEqual(extract_answer("Answer: B"), "B")
+        self.assertEqual(extract_answer("Answer: Yes"), "YES")
+
 
 if __name__ == "__main__":
     unittest.main()