2 changes: 1 addition & 1 deletion .python-version
@@ -1 +1 @@
3.12
3.13
20 changes: 17 additions & 3 deletions pyproject.toml
@@ -3,11 +3,12 @@ name = "complexity-aware-fine-tuning"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.13,<3.14"
dependencies = [
"accelerate==1.3.0",
"datasets>=3.6.0",
"ipywidgets>=8.1.7",
"loguru>=0.7.3",
"matplotlib>=3.10.1",
"mistralai>=1.6.0",
"numpy>=2.2.4",
@@ -16,11 +17,24 @@ dependencies = [
"peft>=0.17.1",
"protobuf>=6.30.2",
"pyarrow>=19.0.1",
"pydraconf>=0.1.0",
"scikit-learn>=1.6.1",
"seaborn>=0.13.2",
"sentencepiece>=0.2.0",
"sentencepiece>=0.2.1",
"torch>=2.6.0",
"transformers==4.52.3",
"transformers>=4.52.3",
"vllm",
]

[tool.uv]
environments = [
"platform_machine == 'arm64' and sys_platform == 'darwin'",
"platform_machine == 'x86_64' and sys_platform == 'linux'",
]

[tool.uv.sources]
vllm = [
{ git = "https://github.com/vllm-project/vllm.git", tag = "v0.11.0", marker = "sys_platform == 'darwin'" },
]

[build-system]
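A note on the new [tool.uv] settings above: the environment markers limit dependency resolution to Apple Silicon macOS and x86_64 Linux, and the [tool.uv.sources] entry pins vLLM to the v0.11.0 git tag on macOS, presumably because no prebuilt wheel is published for that platform. A quick way to see which environment marker matches the current machine (a small sketch, not part of the PR):

import platform
import sys

# These are the same fields the PEP 508 markers in [tool.uv] evaluate.
print(platform.machine(), sys.platform)
# Apple Silicon macOS prints "arm64 darwin"; a typical Linux box prints "x86_64 linux".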
10 changes: 5 additions & 5 deletions src/core/datasets/base_dataset_adapter.py
@@ -15,26 +15,26 @@ class TokenizedRow:


class BaseDatasetAdapter(ABC):
def __init__(self, df_path: str, tokenizer: PreTrainedTokenizer):
def __init__(self, id: str, df_path: str):
self.id: str = id
self.df_path: str = df_path
self.tokenizer: PreTrainedTokenizer = tokenizer

@abstractmethod
def process_row(self, row: pd.Series) -> TokenizedRow: ...
def process_row(self, row: pd.Series, tokenizer: PreTrainedTokenizer) -> TokenizedRow: ...

def _load_df(self) -> pd.DataFrame:
df = pd.read_parquet(
self.df_path,
)
return df

def process_dataset(self):
def process_dataset(self, tokenizer: PreTrainedTokenizer) -> Dataset:
df = self._load_df()

dataset = Dataset.from_pandas(df)

processed_ds = dataset.map(
lambda row: asdict(self.process_row(row)),
lambda row: asdict(self.process_row(row, tokenizer)),
num_proc=4,
remove_columns=dataset.column_names,
)
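The net effect of this change is that adapters no longer hold a tokenizer: they are constructed from an id and a parquet path, and the tokenizer is supplied when the dataset is processed. A minimal sketch of the new call pattern, where the adapter subclass, model name, and paths are hypothetical placeholders:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
adapter = MyAdapter(id="gsm8k", df_path="data/gsm8k_train.parquet")      # hypothetical concrete subclass
train_ds = adapter.process_dataset(tokenizer)  # tokenizer is passed per call, not stored on the adapter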
7 changes: 4 additions & 3 deletions src/core/datasets/causal_dataset_adapter.py
@@ -1,6 +1,7 @@
from abc import abstractmethod

import pandas as pd
from transformers import PreTrainedTokenizer

from core.datasets.base_dataset_adapter import BaseDatasetAdapter, TokenizedRow

@@ -18,20 +19,20 @@ def assistant_response(self, row: pd.Series) -> str: ...
@abstractmethod
def row_id(self, row: pd.Series) -> str: ...

def process_row(self, row: pd.Series) -> TokenizedRow:
def process_row(self, row: pd.Series, tokenizer: PreTrainedTokenizer) -> TokenizedRow:
input_messages = [
{"role": "system", "content": self.system_prompt(row)},
{"role": "user", "content": self.user_prompt(row)},
]

full = self.tokenizer.apply_chat_template(
full = tokenizer.apply_chat_template(
input_messages + [{"role": "assistant", "content": self.assistant_response(row)}],
tokenize=True,
add_generation_prompt=False,
return_dict=True,
)

prefix = self.tokenizer.apply_chat_template(
prefix = tokenizer.apply_chat_template(
input_messages,
tokenize=True,
add_generation_prompt=True,
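The hunk above ends before the TokenizedRow is assembled, but tokenizing the conversation both with and without the assistant turn is the usual setup for masking prompt tokens out of the loss. A sketch of that pattern, which is an assumption about the code below the fold rather than something shown in the diff:

input_ids = full["input_ids"]
prefix_len = len(prefix["input_ids"])
# Common causal-LM convention: positions labelled -100 are ignored by the loss,
# so only the assistant response contributes to training.
labels = [-100] * prefix_len + input_ids[prefix_len:]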
Empty file added src/core/evaluation/__init__.py
53 changes: 53 additions & 0 deletions src/core/evaluation/vllm_cot_evaluator.py
@@ -0,0 +1,53 @@
from typing import override

from core.evaluation.vllm_evaluator import VLLMEvaluator, VLLMEvaluatorConfig
from core.prompts.thinking_markers import THINKING_END


class VLLMCoTEvaluatorConfig(VLLMEvaluatorConfig):
max_tokens: int = 4096


class VLLMCoTEvaluator(VLLMEvaluator):
@override
def _compute_metrics(
self,
outputs: list,
golds: list[str],
question_ids: list[str],
) -> dict:
correct = 0
total = len(outputs)
incorrect: list[dict] = []

for output, gold, qid in zip(outputs, golds, question_ids):
generated_text = output.outputs[0].text
predicted = self._extract_answer(generated_text)
gold_normalized = gold.strip().lower()

if predicted == gold_normalized:
correct += 1
else:
incorrect.append({
"question_id": qid,
"gold": gold,
"predicted": predicted,
"full_output": generated_text,
})

return {
"accuracy": correct / total if total > 0 else 0.0,
"total": total,
"correct": correct,
"incorrect": incorrect,
}

@staticmethod
def _extract_answer(text: str) -> str:
end_idx = text.find(THINKING_END)
if end_idx == -1:
return ""
after_think = text[end_idx + len(THINKING_END):].strip()
if not after_think:
return ""
return after_think[0].lower()
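VLLMCoTEvaluatorConfig raises max_tokens from the base default of 1 to 4096 so the model can write out its reasoning before answering, and _extract_answer keeps only the first character after the closing thinking marker, so it is geared toward single-letter (multiple-choice) answers. For illustration, assuming THINKING_END is the literal string "</think>" (the real value lives in core.prompts.thinking_markers):

text = "The argument hinges on countability...</think>\nB) the diagonal argument"
# Everything before the marker is discarded; the first character after it is lowercased.
assert VLLMCoTEvaluator._extract_answer(text) == "b"
# Without the closing marker, no answer is extracted.
assert VLLMCoTEvaluator._extract_answer("no closing marker here") == ""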
193 changes: 193 additions & 0 deletions src/core/evaluation/vllm_evaluator.py
@@ -0,0 +1,193 @@
import json
from pathlib import Path

import pandas as pd
from pydraconf import PydraConfig
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from core.datasets.causal_dataset_adapter import CausalDatasetAdapter
from core.utils.logger import logger
from core.utils.seed import set_seed


class VLLMEvaluatorConfig(PydraConfig):
model_id: str
checkpoint_dirs: list[str]
eval_datasets: dict[str, CausalDatasetAdapter]
out_path: str
max_tokens: int = 1
temperature: float = 0.0
tensor_parallel_size: int = 1
seed: int = 42


class VLLMEvaluator:
def __init__(self, config: VLLMEvaluatorConfig):
self.config = config
self._llm: LLM | None = None

@property
def llm(self) -> LLM:
if self._llm is None:
first_ckpt = Path(self.config.checkpoint_dirs[0])
is_lora = self._is_lora(first_ckpt)

kwargs = {}
if is_lora:
kwargs["enable_lora"] = True
kwargs["max_lora_rank"] = self._read_lora_rank(first_ckpt)

self._llm = LLM(
model=self.config.model_id,
tensor_parallel_size=self.config.tensor_parallel_size,
seed=self.config.seed,
**kwargs,
)
return self._llm

def evaluate(self) -> dict[str, dict]:
set_seed()

sampling_params = SamplingParams(
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
)

prompts_by_dataset = self._build_all_prompts()

results: dict[str, dict] = {}

for ckpt_idx, ckpt_path in enumerate(self.config.checkpoint_dirs):
ckpt_dir = Path(ckpt_path)
ckpt_name = ckpt_dir.name
logger.info("Evaluating checkpoint: {}", ckpt_name)

lora_request = None
if self._is_lora(ckpt_dir):
lora_request = LoRARequest(
lora_name=ckpt_name,
lora_int_id=ckpt_idx + 1,
lora_local_path=ckpt_path,
)

dataset_results: dict[str, dict] = {}
combined_correct = 0
combined_total = 0

for ds_name, (messages_list, golds, question_ids) in prompts_by_dataset.items():
outputs = self.llm.chat(
messages=messages_list,
sampling_params=sampling_params,
lora_request=lora_request,
)

metrics = self._compute_metrics(outputs, golds, question_ids)
dataset_results[ds_name] = metrics
combined_correct += metrics["correct"]
combined_total += metrics["total"]

logger.info("{} — {}: accuracy={:.4f}", ckpt_name, ds_name, metrics["accuracy"])

combined_accuracy = combined_correct / combined_total if combined_total > 0 else 0.0
dataset_results["combined"] = {
"accuracy": combined_accuracy,
"total": combined_total,
"correct": combined_correct,
}

logger.info("{} — combined: accuracy={:.4f}", ckpt_name, combined_accuracy)

results[ckpt_name] = dataset_results
self._save_results(ckpt_name, dataset_results)

return results

def _build_all_prompts(self) -> dict[str, tuple[list[list[dict[str, str]]], list[str], list[str]]]:
prompts_by_dataset: dict[str, tuple[list[list[dict[str, str]]], list[str], list[str]]] = {}

for ds_name, adapter in self.config.eval_datasets.items():
df = adapter._load_df()
messages_list: list[list[dict[str, str]]] = []
golds: list[str] = []
question_ids: list[str] = []

for _, row in df.iterrows():
messages = [
{"role": "system", "content": adapter.system_prompt(row)},
{"role": "user", "content": adapter.user_prompt(row)},
]
messages_list.append(messages)
golds.append(adapter.assistant_response(row))
question_ids.append(adapter.row_id(row))

prompts_by_dataset[ds_name] = (messages_list, golds, question_ids)

return prompts_by_dataset

def _compute_metrics(
self,
outputs: list,
golds: list[str],
question_ids: list[str],
) -> dict:
correct = 0
total = len(outputs)
incorrect: list[dict] = []

for output, gold, qid in zip(outputs, golds, question_ids):
predicted = output.outputs[0].text.strip().lower()
gold_normalized = gold.strip().lower()

if predicted == gold_normalized:
correct += 1
else:
incorrect.append(
{
"question_id": qid,
"gold": gold,
"predicted": predicted,
}
)

return {
"accuracy": correct / total if total > 0 else 0.0,
"total": total,
"correct": correct,
"incorrect": incorrect,
}

def _save_results(self, ckpt_name: str, results_by_dataset: dict[str, dict]) -> None:
out_dir = Path(self.config.out_path) / ckpt_name
out_dir.mkdir(parents=True, exist_ok=True)

metrics_summary = {}
all_incorrect: list[dict] = []

for ds_name, result in results_by_dataset.items():
metrics_summary[ds_name] = {k: v for k, v in result.items() if k != "incorrect"}
for item in result.get("incorrect", []):
item_with_ds = {**item, "dataset": ds_name}
all_incorrect.append(item_with_ds)

with open(out_dir / "metrics.json", "w") as f:
json.dump(metrics_summary, f, indent=2)

if all_incorrect:
pd.DataFrame(all_incorrect).to_csv(
out_dir / "incorrect_answers.tsv",
sep="\t",
index=False,
)

logger.info("Results saved to {}", out_dir)

@staticmethod
def _is_lora(checkpoint_dir: Path) -> bool:
return (checkpoint_dir / "adapter_config.json").exists()

@staticmethod
def _read_lora_rank(checkpoint_dir: Path) -> int:
with open(checkpoint_dir / "adapter_config.json") as f:
config = json.load(f)
return config["r"]
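End-to-end, the evaluator is driven by a config that names the base model, the checkpoints to score, and the eval datasets. Wiring it up would look roughly like the sketch below; the adapter class, model id, and paths are placeholders, and the exact way a PydraConfig is instantiated may differ in this repo:

config = VLLMEvaluatorConfig(
    model_id="Qwen/Qwen2.5-7B-Instruct",  # placeholder base model
    checkpoint_dirs=["outputs/sft/checkpoint-500", "outputs/sft/checkpoint-1000"],
    eval_datasets={"arc": ArcAdapter(id="arc", df_path="data/arc_test.parquet")},  # hypothetical adapter
    out_path="outputs/eval",
)
results = VLLMEvaluator(config).evaluate()
print(results["checkpoint-1000"]["combined"]["accuracy"])

Note that whether LoRA is enabled is decided from the first checkpoint's adapter_config.json, so the checkpoint list is expected to be homogeneous: either all LoRA adapters or all full fine-tunes.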