Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/
*.pyc
36 changes: 3 additions & 33 deletions evaluate_gsm8k.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,21 @@
import argparse
import csv
import re
import math
from fractions import Fraction

from datasets import load_dataset

import inference


# Loose numeric pattern: optional minus sign, one or more digits, an optional
# "/" followed by optional digits, then an optional "." followed by optional
# digits.
# NOTE(review): no use of NUM_RE is visible in this section - confirm it is
# still needed.
NUM_RE = re.compile(r"-?\d+\/?\d*\.?\d*")


def parse_numeric(s: str):
    """Extract the last numeric value found in a string.

    Recognises optionally negative integers, decimals and simple fractions
    such as ``3/4`` or ``-3/4``. The last numeric token wins because the
    final answer usually comes at the end of the text.

    Args:
        s: Text to search.

    Returns:
        The last numeric token as a float, or None when ``s`` is empty, is
        not a string, contains no number, or the token cannot be converted
        (for example a fraction with a zero denominator).
    """
    if not s or not isinstance(s, str):
        return None

    # One pattern covers integers, decimals and fractions, so the sign stays
    # attached to a fraction and "last token" means the same for every form.
    tokens = re.findall(r"-?\d+(?:/\d+)?(?:\.\d+)?", s)
    if not tokens:
        return None

    token = tokens[-1]
    try:
        if "/" in token:
            return float(Fraction(token))
        return float(token)
    except (ValueError, ZeroDivisionError, OverflowError):
        # Malformed token (e.g. "3/4.5"), zero denominator, or a value too
        # large for a float: report "no answer" instead of guessing.
        return None
from utils import extract_numeric


def normalize_reference_answer(ans: str):
    """Reduce a GSM8K reference answer to its numeric value.

    Reference answers sometimes include an explanation, so the number is
    pulled out of the text with the shared ``extract_numeric`` helper.

    Args:
        ans: Reference answer text.

    Returns:
        Whatever ``extract_numeric`` returns for ``ans``.
    """
    return extract_numeric(ans)


def extract_predicted_answer(text: str):
    """Pull the predicted answer out of model output.

    Heuristic: the answer is taken to be the last numeric occurrence in the
    text, as located by the shared ``extract_numeric`` helper.

    Args:
        text: Raw model output.

    Returns:
        Whatever ``extract_numeric`` returns for ``text``.
    """
    return extract_numeric(text)


def main():
Expand Down
34 changes: 34 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils import extract_numeric

def test_extract_numeric():
    """Exercise extract_numeric on None, non-string, numeric and free-text inputs."""
    # A missing value produces no result.
    assert extract_numeric(None) is None

    # Each input is paired with the float it must produce.
    value_cases = [
        # Wrong type input
        (123, 123.0),
        # Simple numbers
        ("42", 42.0),
        ("-10", -10.0),
        ("3.14", 3.14),
        ("-2.5", -2.5),
        # Fractions
        ("1/2", 0.5),
        ("-3/4", -0.75),
        # Text with numbers
        ("The answer is 42", 42.0),
        ("I have 3 apples and 2.5 oranges", 2.5),
        ("Result: 5/2 units", 2.5),
    ]
    for raw, wanted in value_cases:
        assert extract_numeric(raw) == wanted

    # Complex or weird formats with nothing numeric in them
    for raw in ("No numbers here!", "Just a dot . "):
        assert extract_numeric(raw) is None

    print("All tests passed!")


if __name__ == "__main__":
    test_extract_numeric()
23 changes: 23 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import re
from fractions import Fraction

def extract_numeric(text):
    """Extract the last numeric value from ``text``.

    Recognises optionally negative integers, decimals, simple fractions such
    as ``3/4``, and integers written with thousands separators such as
    ``1,000``. The last numeric token in the text wins.

    Args:
        text: Value to search. Non-string values are converted with ``str``.

    Returns:
        The last numeric token as a float, or None when ``text`` is None,
        contains no number, or the token cannot be converted (for example a
        fraction with a zero denominator).
    """
    if text is None:
        return None

    if not isinstance(text, str):
        text = str(text)

    # The first alternative accepts properly grouped thousands ("12,345,678");
    # the lookahead rejects malformed groupings such as "1,0000", which then
    # fall back to plain digit runs exactly as before.
    matches = re.findall(
        r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:/\d+)?(?:\.\d+)?",
        text,
    )
    if not matches:
        return None

    last = matches[-1].replace(",", "")
    try:
        if "/" in last:
            return float(Fraction(last))
        return float(last)
    except (ValueError, ZeroDivisionError, OverflowError):
        # Malformed token (e.g. "3/4.5"), zero denominator, or a value too
        # large for a float.
        return None
22 changes: 4 additions & 18 deletions web/streamlit_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,14 @@
import json
import csv
import pandas as pd
import re
from fractions import Fraction
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import extract_numeric
from inference import generate_solution
Comment on lines +6 to 11


def extract_numeric(text):
    """Extract the last numeric value from ``text``.

    Recognises optionally negative integers, decimals and simple fractions
    such as ``3/4``. The last numeric token in the text wins.

    Args:
        text: Value to search. Non-string values are converted with ``str``
            so that numbers passed in directly do not raise ``TypeError``.

    Returns:
        The last numeric token as a float, or None when ``text`` is None,
        contains no number, or the token cannot be converted (for example a
        fraction with a zero denominator).
    """
    if text is None:
        return None

    # re.findall only accepts str/bytes; coerce anything else first.
    if not isinstance(text, str):
        text = str(text)

    # find fractions like 3/4 or decimals/integers
    matches = re.findall(r"-?\d+(?:/\d+)?(?:\.\d+)?", text)
    if not matches:
        return None

    last = matches[-1]
    try:
        if "/" in last:
            return float(Fraction(last))
        return float(last)
    except (ValueError, ZeroDivisionError, OverflowError):
        # Malformed token (e.g. "3/4.5"), zero denominator, or a value too
        # large for a float.
        return None


def solve_and_display(problem, cot, temperature, top_p, max_new_tokens, base_model, adapter_path):
with st.spinner("Generating solution..."):
out = generate_solution(problem, cot=cot, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, base_model=base_model, adapter_path=adapter_path)
Expand Down
Loading