2 changes: 1 addition & 1 deletion .python-version
@@ -1 +1 @@
3.12
3.13
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -3,11 +3,12 @@ name = "complexity-aware-fine-tuning"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.13"
dependencies = [
"accelerate==1.3.0",
"datasets>=3.6.0",
"ipywidgets>=8.1.7",
"loguru>=0.7.3",
"matplotlib>=3.10.1",
"mistral-common>=1.9.1",
"mistralai>=1.6.0",
@@ -20,7 +21,7 @@ dependencies = [
"pydraconf>=0.1.0",
"scikit-learn>=1.6.1",
"seaborn>=0.13.2",
"sentencepiece>=0.2.0",
"sentencepiece>=0.2.1",
"torch>=2.6.0",
"transformers==4.52.3",
]
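A likely reason for the jump to Python 3.13 (an inference, not stated in the diff): the new `base_trainer.py` below declares generics with PEP 696 type-parameter defaults, e.g. `class BaseTrainerConfig[TTrainingArgs: BaseTrainingArgs = BaseTrainingArgs]`, and the `= default` part of that syntax is only accepted by the 3.13 parser. A minimal standalone illustration:

```python
# PEP 696 (Python 3.13+): a type parameter with both a bound and a default.
class Box[T: int = int]:
    def __init__(self, value: T) -> None:
        self.value = value

def unwrap(box: Box) -> int:
    # A bare `Box` annotation resolves to Box[int] through the default.
    return box.value

print(unwrap(Box(3)))  # 3
```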
169 changes: 169 additions & 0 deletions src/core/training/base_trainer.py
@@ -0,0 +1,169 @@
import json
import subprocess
from pathlib import Path
from typing import Any

from pydantic import BaseModel
from pydraconf import PydraConfig
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForTokenClassification,
PreTrainedTokenizer,
Seq2SeqTrainingArguments,
)
from transformers.trainer_seq2seq import Seq2SeqTrainer

from core.datasets.abstract_dataset_adapter import AbstractDatasetAdapter
from core.training.callbacks.save_by_schedule import SaveByScheduleCallback
from core.utils.last_checkpoint_dir import get_last_checkpoint_dir
from core.utils.logger import logger
from core.utils.seed import set_seed


class BaseTrainingArgs(BaseModel):
    num_train_epochs: int
    effective_train_batch_size: int = 256
    per_device_train_batch_size: int

    # Sane defaults for SFT fine-tuning
    learning_rate: float = 2e-5
    lr_scheduler_type: str = "cosine"
    warmup_ratio: float = 0.03
    weight_decay: float = 0.1
    max_grad_norm: float = 1.0
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    bf16: bool = True
    report_to: str = "none"
    seed: int = 42
    data_seed: int = 42
    torch_compile: bool = True
    save_strategy: str = "epoch"
    logging_steps: int = 10
    logging_first_step: bool = True


class BaseTrainerConfig[TTrainingArgs: BaseTrainingArgs = BaseTrainingArgs](PydraConfig):
    out_path: str
    model_id: str
    train_dataset: AbstractDatasetAdapter
    training_args: TTrainingArgs
    save_schedule: list[int] | None = None


class BaseTrainer[TConfig: BaseTrainerConfig[Any] = BaseTrainerConfig]:
    def __init__(self, config: TConfig, tokenizer: PreTrainedTokenizer | None = None):
        self.config = config
        self._tokenizer: PreTrainedTokenizer | None = tokenizer
        self._model = None  # lazily created in the `model` property

    def train(self):
        if not self._directory_is_empty(self.config.out_path, self.config.training_args.num_train_epochs):
            logger.error(f"BaseTrainer.train -> out_path not empty: {self.config.out_path}")
            return None

        set_seed()

        logger.info(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)

        train_ds = self._prepare_data()
        self._run_training(train_ds)

        return get_last_checkpoint_dir(self.config.out_path)

    @property
    def tokenizer(self):
        if not self._tokenizer:
            self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_id)

        assert isinstance(self._tokenizer, PreTrainedTokenizer), (
            "Tokenizer must be a PreTrainedTokenizer, but got {}".format(type(self._tokenizer))
        )

        if self._tokenizer.pad_token is None:
            logger.warning("Tokenizer has no pad token, setting it to eos token")
            self._tokenizer.pad_token = self._tokenizer.eos_token

        return self._tokenizer

    @property
    def model(self):
        if not self._model:
            self._model = AutoModelForCausalLM.from_pretrained(self.config.model_id)

        assert self._model is not None, "Model should be initialized"
        return self._model

    @property
    def data_collator(self):
        return DataCollatorForTokenClassification(
            tokenizer=self.tokenizer, padding=True, pad_to_multiple_of=8, return_tensors="pt"
        )

    @property
    def training_args(self):
        return Seq2SeqTrainingArguments(
            # Exclude the two batch-size knobs: they are not TrainingArguments fields
            # and are translated into per-device batch size + accumulation steps below.
            **self.config.training_args.model_dump(exclude={"effective_train_batch_size", "per_device_train_batch_size"}),
            **self._batch_size_config(
                self.config.training_args.effective_train_batch_size,
                self.config.training_args.per_device_train_batch_size,
            ),
            output_dir=self.config.out_path,
        )

    def _prepare_data(self):
        train_ds = self.config.train_dataset.process_dataset()
        logger.info("Dataset samples")
        logger.info("Train")
        logger.info(f"Input: {self.tokenizer.decode(train_ds[0]['input_ids'])}")
        logger.info(f"Labels: {self.tokenizer.decode(train_ds[0]['labels'])}")

        return train_ds

    def _run_training(self, train_ds):
        trainer = Seq2SeqTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=train_ds,
            data_collator=self.data_collator,
            processing_class=self.tokenizer,
        )

        if self.config.save_schedule is not None:
            trainer.add_callback(SaveByScheduleCallback(schedule=self.config.save_schedule))

        # Resume only if a checkpoint already exists; passing True with an empty
        # output directory makes Trainer.train raise a ValueError.
        resume = any(Path(self.config.out_path).glob("checkpoint-*"))
        trainer.train(resume_from_checkpoint=resume)

    def _directory_is_empty(self, directory: str, expected_epochs: int) -> bool:
        """Return True if training may run in `directory`: it is missing, empty,
        or holds only an unfinished run that can be resumed."""
        p = Path(directory)
        if not p.exists():
            return True
        if not p.is_dir():
            raise Exception("Not a directory!")

        checkpoint_dirs = list(p.glob("checkpoint-*"))
        if not checkpoint_dirs:
            return True

        checkpoint_dirs.sort(key=lambda x: int(x.name.split("-")[1]))
        last_checkpoint = checkpoint_dirs[-1]

        state_file = last_checkpoint / "trainer_state.json"
        if state_file.exists():
            with open(state_file, "r") as f:
                state = json.load(f)
            if int(state.get("epoch", 0)) == expected_epochs:
                return False

        return True

    def _batch_size_config(self, effective_batch_size: int, per_device_train_batch_size: int):
        assert effective_batch_size % per_device_train_batch_size == 0, (
            f"Effective batch size {effective_batch_size} is not divisible by per device batch size {per_device_train_batch_size}"
        )
        gradient_accumulation_steps = effective_batch_size // per_device_train_batch_size
        return {
            "per_device_train_batch_size": per_device_train_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
        }
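For orientation, a hedged sketch of how this base class looks meant to be used (assuming the PydraConfig base allows direct keyword construction like a pydantic model; the dataset adapter, model id, and paths are placeholders, not part of this PR). Note how the effective batch size of 256 becomes gradient accumulation: with `per_device_train_batch_size=8`, 256 / 8 = 32 accumulation steps.

```python
# Hypothetical usage sketch; MyDatasetAdapter, the model id, and paths are illustrative only.
from core.training.base_trainer import BaseTrainer, BaseTrainerConfig, BaseTrainingArgs

config = BaseTrainerConfig(
    out_path="outputs/sft-run",            # checkpoints appear here as checkpoint-*/
    model_id="mistralai/Mistral-7B-v0.3",  # any causal-LM hub id
    train_dataset=MyDatasetAdapter(),      # some AbstractDatasetAdapter subclass
    training_args=BaseTrainingArgs(
        num_train_epochs=3,
        per_device_train_batch_size=8,     # 256 effective / 8 per device -> 32 accumulation steps
    ),
    save_schedule=[1, 3],                  # optional: save checkpoints only after epochs 1 and 3
)

last_checkpoint = BaseTrainer(config).train()  # last checkpoint dir, or None if a finished run exists
```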
58 changes: 0 additions & 58 deletions src/core/training/callbacks/eval_every_n_epoch.py

This file was deleted.

61 changes: 0 additions & 61 deletions src/core/training/callbacks/save_and_log_weights.py

This file was deleted.

18 changes: 18 additions & 0 deletions src/core/training/callbacks/save_by_schedule.py
@@ -0,0 +1,18 @@
from typing import override

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class SaveByScheduleCallback(TrainerCallback):
    def __init__(self, schedule: list[int]):
        self.schedule = schedule

    @override
    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        assert state.epoch is not None

        # Save a checkpoint only when the just-finished epoch is in the schedule.
        epoch_num = int(state.epoch)
        control.should_save = epoch_num in self.schedule
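A quick way to sanity-check the schedule semantics without a training run is to drive the callback by hand with real `TrainerState`/`TrainerControl` objects (a sketch; the `/tmp` path is arbitrary):

```python
from transformers import TrainerControl, TrainerState, TrainingArguments

from core.training.callbacks.save_by_schedule import SaveByScheduleCallback

# Saves should fire only after epochs 1 and 3.
callback = SaveByScheduleCallback(schedule=[1, 3])
args = TrainingArguments(output_dir="/tmp/save-by-schedule-demo")
control = TrainerControl()

for epoch in (1.0, 2.0, 3.0):
    state = TrainerState()
    state.epoch = epoch
    callback.on_epoch_end(args, state, control)
    print(int(epoch), control.should_save)  # 1 True, 2 False, 3 True
```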
16 changes: 0 additions & 16 deletions src/core/training/callbacks/save_every_n_epoch.py

This file was deleted.
