From 6b50712b2ac332f6eceb0b5817f614f4185626ac Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:11:39 +0000 Subject: [PATCH] Refactor numeric extraction logic into a shared utility function Co-authored-by: dhanush342 <187305764+dhanush342@users.noreply.github.com> --- .gitignore | 2 ++ evaluate_gsm8k.py | 36 +++--------------------------------- tests/test_utils.py | 34 ++++++++++++++++++++++++++++++++++ utils.py | 23 +++++++++++++++++++++++ web/streamlit_dashboard.py | 22 ++++------------------ 5 files changed, 66 insertions(+), 51 deletions(-) create mode 100644 .gitignore create mode 100644 tests/test_utils.py create mode 100644 utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7a60b85 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/evaluate_gsm8k.py b/evaluate_gsm8k.py index 3826a32..fccbcd7 100644 --- a/evaluate_gsm8k.py +++ b/evaluate_gsm8k.py @@ -1,51 +1,21 @@ import argparse import csv -import re import math -from fractions import Fraction from datasets import load_dataset import inference - - -NUM_RE = re.compile(r"-?\d+\/?\d*\.?\d*") - - -def parse_numeric(s: str): - """Try to extract a numeric value from a string. Returns float or None.""" - if not s or not isinstance(s, str): - return None - - # Try fraction first - frac_match = re.search(r"(\d+)/(\d+)", s) - if frac_match: - try: - return float(Fraction(int(frac_match.group(1)), int(frac_match.group(2)))) - except Exception: - pass - - # Find decimals or integers - nums = re.findall(r"-?\d+\.?\d*", s) - if not nums: - return None - - # Prefer last numeric token (often final answer) - token = nums[-1] - try: - return float(token) - except Exception: - return None +from utils import extract_numeric def normalize_reference_answer(ans: str): # GSM8K references sometimes include explanation; extract numeric - return parse_numeric(ans) + return extract_numeric(ans) def extract_predicted_answer(text: str): # heuristic: look for last numeric occurrence in model output - return parse_numeric(text) + return extract_numeric(text) def main(): diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..182a9cd --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,34 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils import extract_numeric + +def test_extract_numeric(): + # Null and wrong type inputs + assert extract_numeric(None) is None + assert extract_numeric(123) == 123.0 + + # Simple numbers + assert extract_numeric("42") == 42.0 + assert extract_numeric("-10") == -10.0 + assert extract_numeric("3.14") == 3.14 + assert extract_numeric("-2.5") == -2.5 + + # Fractions + assert extract_numeric("1/2") == 0.5 + assert extract_numeric("-3/4") == -0.75 + + # Text with numbers + assert extract_numeric("The answer is 42") == 42.0 + assert extract_numeric("I have 3 apples and 2.5 oranges") == 2.5 + assert extract_numeric("Result: 5/2 units") == 2.5 + + # Complex or weird formats + assert extract_numeric("No numbers here!") is None + assert extract_numeric("Just a dot . ") is None + + print("All tests passed!") + +if __name__ == "__main__": + test_extract_numeric() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..9cc56fb --- /dev/null +++ b/utils.py @@ -0,0 +1,23 @@ +import re +from fractions import Fraction + +def extract_numeric(text): + """Try to extract a numeric value from a string. Returns float or None.""" + if text is None: + return None + + if not isinstance(text, str): + text = str(text) + + # find fractions like 3/4 or decimals/integers + matches = re.findall(r"-?\d+(?:/\d+)?(?:\.\d+)?", text) + if not matches: + return None + + last = matches[-1] + try: + if "/" in last: + return float(Fraction(last)) + return float(last) + except Exception: + return None diff --git a/web/streamlit_dashboard.py b/web/streamlit_dashboard.py index c76da3a..71e7712 100644 --- a/web/streamlit_dashboard.py +++ b/web/streamlit_dashboard.py @@ -3,28 +3,14 @@ import json import csv import pandas as pd -import re -from fractions import Fraction +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import extract_numeric from inference import generate_solution -def extract_numeric(text): - if text is None: - return None - # find fractions like 3/4 or decimals/integers - matches = re.findall(r"-?\d+(?:/\d+)?(?:\.\d+)?", text) - if not matches: - return None - last = matches[-1] - try: - if "/" in last: - return float(Fraction(last)) - return float(last) - except Exception: - return None - - def solve_and_display(problem, cot, temperature, top_p, max_new_tokens, base_model, adapter_path): with st.spinner("Generating solution..."): out = generate_solution(problem, cot=cot, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, base_model=base_model, adapter_path=adapter_path)