Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 98 additions & 12 deletions experiment/benchmarks/bbh.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,51 @@ def _format_choices(choices: dict[str, Any]) -> str:
return "\n".join(lines)


def format_prompt(question: str, choices: dict[str, Any]) -> str:
"""Build the model-facing prompt for a BBH multiple-choice item."""
def _has_choices(choices: Any) -> bool:
"""True if the dataset row actually exposed a multiple-choice list."""
if not choices:
return False
if not isinstance(choices, dict):
return False
return bool(choices.get("text") or choices.get("label"))


def format_prompt(question: str, choices: dict[str, Any] | None) -> str:
"""Build the model-facing prompt for a BBH item.

Two shapes:
- MCQ (has `choices`) → "Answer: X" letter convention.
- Yes/No (no `choices`, e.g. causal_judgement, web_of_lies, navigate)
→ "Answer: Yes" or "Answer: No".
"""
if _has_choices(choices):
return (
f"{question.strip()}\n\n"
f"Choices:\n{_format_choices(choices)}\n\n"
"Reply with your reasoning, then end with exactly one line:\n"
"Answer: X\n"
"where X is the letter of the correct choice (A, B, C, ...)."
)
return (
f"{question.strip()}\n\n"
f"Choices:\n{_format_choices(choices)}\n\n"
"Reply with your reasoning, then end with exactly one line:\n"
"Answer: X\n"
"where X is the letter of the correct choice (A, B, C, ...)."
"Answer: Yes\n"
"or\n"
"Answer: No"
)


def normalize_gold(target: str) -> str:
    """Normalize a BBH gold target.

    Two shapes are handled:
      - Yes/No (causal_judgement and friends): returned as "YES" / "NO",
        matched case-insensitively after stripping whitespace.
      - MCQ letter (A, B, C, ...): whatever ``extract_choice`` parses from
        the stripped target.

    The Yes/No check runs first so that such targets never reach the letter
    extractor.

    Args:
        target: Raw gold value from the dataset row (coerced with ``str``).

    Returns:
        "YES", "NO", or the extracted choice letter.

    Raises:
        ValueError: If the target is not yes/no and ``extract_choice``
            returns ``None`` for it.
    """
    t = str(target).strip()
    if t.lower() in ("yes", "no"):
        return t.upper()
    letter = extract_choice(t)
    if letter is None:
        raise ValueError(f"Could not parse gold target: {target!r}")
    return letter
Expand Down Expand Up @@ -120,28 +151,83 @@ def extract_choice(text: str) -> str | None:
return None


def extract_yes_no(text: str) -> str | None:
"""Parse Yes / No from model output. Returns "YES", "NO", or None.

Mirrors the layered approach of `extract_choice`: explicit "Answer: Yes/No"
first, then "the answer is yes/no", then bare yes/no at the end.
"""
if not text or not text.strip():
return None

patterns = [
r"(?im)^\s*answer\s*:\s*(yes|no)\s*\.?\s*$",
r"(?im)^\s*answer\s*:\s*(yes|no)\b",
r"(?i)\b(?:the\s+)?answer\s+is\s+(yes|no)\b",
]
for pat in patterns:
matches = re.findall(pat, text)
if matches:
return matches[-1].upper()

# Fallback: yes/no in the last non-empty line, then anywhere in last 5 lines.
lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
if lines:
m = re.search(r"\b(yes|no)\b", lines[-1], re.IGNORECASE)
if m:
return m.group(1).upper()
tail = "\n".join(lines[-5:])
m = re.search(r"\b(yes|no)\b\s*\.?\s*$", tail, re.IGNORECASE)
if m:
return m.group(1).upper()
return None


def extract_answer(text: str) -> str | None:
    """Extract a final answer of either shape: Yes/No first, then MCQ letter.

    The Yes/No extractor is consulted before the letter extractor. The stated
    reason is that extract_choice's "Answer: X" pattern has no end anchor and
    would otherwise pick up the "Y" of "Answer: Yes" (extract_choice is not
    shown here, so treat that as the author's note). For a plain letter
    answer such as "Answer: B", extract_yes_no yields None and the letter
    extractor decides.

    Intended for callers that do not know the task shape (e.g. cross-task
    lexical agreement). For scoring, prefer the gold-aware routing in
    score_bbh.
    """
    verdict = extract_yes_no(text)
    if verdict:
        return verdict
    return extract_choice(text)


def score_bbh(model_output: str, task: dict) -> tuple[bool, str]:
    """Grade a BBH response against ``task["gold"]``.

    Extraction is routed by the shape of the gold answer: "YES"/"NO" golds
    use ``extract_yes_no``, every other gold uses ``extract_choice``. This
    avoids the false positive where a Yes/No question's reasoning happens to
    contain a stray "(A)" string.

    Args:
        model_output: Raw model response text.
        task: Task dict; must contain a normalized ``"gold"`` entry.

    Returns:
        ``(True, "passed")`` on a match, ``(False, "unparseable")`` when no
        answer could be extracted, otherwise
        ``(False, "expected <gold> got <pred>")``.
    """
    gold = task["gold"]
    if gold in ("YES", "NO"):
        pred = extract_yes_no(model_output)
    else:
        pred = extract_choice(model_output)
    if pred is None:
        return False, "unparseable"
    if pred == gold:
        return True, "passed"
    return False, f"expected {gold} got {pred}"


def _row_to_task(subtask: str, index: int, row: dict) -> dict:
    """Convert one dataset row into a task dict.

    Args:
        subtask: BBH subtask name; used in the task id and echoed back.
        index: Row position within the subtask; used in the task id.
        row: Dataset row. Must provide ``"question"`` and ``"target"``;
            ``"choices"`` is optional and may be absent for Yes/No subtasks.

    Returns:
        A dict with ``task_id``, ``prompt``, ``gold``, ``benchmark``,
        ``subtask``, plus the raw ``question`` and ``choices`` fields.
    """
    gold = normalize_gold(row["target"])
    # .get() rather than indexing: Yes/No rows may carry no "choices" key.
    choices = row.get("choices")
    return {
        "task_id": f"bbh/{subtask}/{index}",
        "prompt": format_prompt(row["question"], choices),
        "gold": gold,
        "benchmark": "bbh",
        "subtask": subtask,
        # Echo arms use task["prompt"]; keep raw fields for debugging.
        "question": row["question"],
        "choices": choices,
    }


Expand Down
11 changes: 8 additions & 3 deletions experiment/benchmarks/bbh_arms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableLambda, RunnableParallel

from benchmarks.bbh import extract_choice, score_bbh
from benchmarks.bbh import extract_answer, score_bbh
from chat_claude_code import ChatClaudeCode
from run_pilot import SMALL_JUDGE_BASE_URL, SMALL_JUDGE_MODEL, _HAS_OLLAMA

Expand Down Expand Up @@ -55,8 +55,13 @@ def _normalize_answer_text(text: str) -> str:


def lexical_agree(a: str, b: str) -> bool:
    """Return True when two responses agree on their final answer.

    If ``extract_answer`` parses an answer (choice letter or YES/NO) from
    both responses, agreement is equality of those answers. Otherwise the
    comparison falls back to equality of the ``_normalize_answer_text`` forms
    of the full responses.

    Using the unified extractor gives Yes/No-shaped subtasks
    (causal_judgement, web_of_lies, navigate, sports_understanding) the same
    fast path as MCQ subtasks.
    """
    ca, cb = extract_answer(a), extract_answer(b)
    if ca is not None and cb is not None:
        return ca == cb
    return _normalize_answer_text(a) == _normalize_answer_text(b)
Expand Down
65 changes: 64 additions & 1 deletion experiment/tests/test_bbh_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from benchmarks.bbh import extract_choice, format_prompt, normalize_gold, score_bbh
from benchmarks.bbh import (
extract_answer,
extract_choice,
extract_yes_no,
format_prompt,
normalize_gold,
score_bbh,
)


class TestExtractChoice(unittest.TestCase):
Expand Down Expand Up @@ -69,6 +76,62 @@ def test_includes_choices(self) -> None:
self.assertIn("Ann", p)
self.assertIn("Answer: X", p)

def test_yes_no_prompt_when_no_choices(self) -> None:
    """With choices=None the prompt uses the Yes/No answer convention."""
    prompt = format_prompt("Did X cause Y?", None)
    # The question and both accepted answer lines must be present.
    for expected in ("Did X cause Y?", "Answer: Yes", "Answer: No"):
        self.assertIn(expected, prompt)
    # The MCQ letter convention must not appear.
    self.assertNotIn("Answer: X", prompt)

def test_yes_no_prompt_when_empty_choices(self) -> None:
    """An empty choices dict is treated the same as no choices at all."""
    prompt = format_prompt("Did X cause Y?", {})
    self.assertIn("Answer: Yes", prompt)


class TestYesNoSupport(unittest.TestCase):
    """Yes/No answer shape: extraction, gold normalization, and scoring."""

    def test_extract_yes_no_explicit(self) -> None:
        """An explicit "Answer: Yes/No" line is recognized."""
        cases = (
            ("Reasoning...\nAnswer: Yes", "YES"),
            ("Reasoning...\nAnswer: No", "NO"),
        )
        for reply, expected in cases:
            self.assertEqual(extract_yes_no(reply), expected)

    def test_extract_yes_no_phrase(self) -> None:
        """The "the answer is yes/no" phrasing is recognized."""
        cases = (
            ("So the answer is yes", "YES"),
            ("Therefore the answer is no", "NO"),
        )
        for reply, expected in cases:
            self.assertEqual(extract_yes_no(reply), expected)

    def test_extract_yes_no_last_line(self) -> None:
        """A bare verdict on the final line is recognized."""
        reply = "Long reasoning ...\n...\nYes."
        self.assertEqual(extract_yes_no(reply), "YES")

    def test_extract_yes_no_unparseable(self) -> None:
        """Text without a yes/no verdict, or empty text, yields None."""
        for reply in ("Maybe, but possibly", ""):
            self.assertIsNone(extract_yes_no(reply))

    def test_normalize_gold_yes_no(self) -> None:
        """Yes/No golds are upper-cased regardless of input case."""
        for raw, expected in (("Yes", "YES"), ("No", "NO"), ("yes", "YES")):
            self.assertEqual(normalize_gold(raw), expected)

    def test_normalize_gold_letter_still_works(self) -> None:
        """Letter golds keep normalizing as before."""
        for raw, expected in (("A", "A"), ("(C)", "C")):
            self.assertEqual(normalize_gold(raw), expected)

    def test_score_yes_no_task(self) -> None:
        """Scoring a YES-gold task passes on Yes and reports a mismatch on No."""
        task = {"task_id": "t", "prompt": "p", "gold": "YES"}

        passed, _ = score_bbh("Answer: Yes", task)
        self.assertTrue(passed)

        passed, detail = score_bbh("Answer: No", task)
        self.assertFalse(passed)
        self.assertIn("expected YES got NO", detail)

    def test_score_yes_no_ignores_stray_mcq_letters(self) -> None:
        """A stray "(A)" in the reasoning must not be scored as an MCQ answer."""
        task = {"task_id": "t", "prompt": "p", "gold": "YES"}
        reply = "Reasoning mentions (A) in passing.\nAnswer: Yes"
        passed, _ = score_bbh(reply, task)
        self.assertTrue(passed)

    def test_extract_answer_unified(self) -> None:
        """The unified extractor handles both letter and Yes/No answers."""
        for reply, expected in (("Answer: B", "B"), ("Answer: Yes", "YES")):
            self.assertEqual(extract_answer(reply), expected)


if __name__ == "__main__":
unittest.main()