diff --git a/notebooks/train-hippofloop.ipynb b/notebooks/train-hippofloop.ipynb index 0c56293..b4c7de2 100644 --- a/notebooks/train-hippofloop.ipynb +++ b/notebooks/train-hippofloop.ipynb @@ -227,43 +227,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from transformers import TrainingArguments\n", - "from trl import SFTTrainer\n", - "\n", - "OUTPUT_DIR = \"checkpoints/qwen25-3b-hippofloop\"\n", - "\n", - "training_args = TrainingArguments(\n", - " output_dir=OUTPUT_DIR,\n", - " num_train_epochs=3,\n", - " per_device_train_batch_size=1,\n", - " gradient_accumulation_steps=16,\n", - " learning_rate=2e-4,\n", - " lr_scheduler_type=\"cosine\",\n", - " warmup_ratio=0.03,\n", - " weight_decay=0.01,\n", - " bf16=False,\n", - " fp16=True,\n", - " eval_strategy=\"epoch\",\n", - " save_strategy=\"epoch\",\n", - " load_best_model_at_end=True,\n", - " metric_for_best_model=\"eval_loss\",\n", - " logging_steps=10,\n", - " seed=SEED,\n", - ")\n", - "\n", - "sft_trainer = SFTTrainer(\n", - " model=model,\n", - " tokenizer=tokenizer,\n", - " train_dataset=train_dataset,\n", - " eval_dataset=val_dataset,\n", - " args=training_args,\n", - ")\n", - "\n", - "print(f\"Training {len(train_dataset)} examples for {training_args.num_train_epochs} epochs...\")\n", - "print(f\"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}\")\n", - "sft_trainer.train()" - ] + "source": "from transformers import TrainingArguments\nfrom trl import SFTTrainer\n\nOUTPUT_DIR = \"checkpoints/qwen25-3b-hippofloop\"\n\ntraining_args = TrainingArguments(\n output_dir=OUTPUT_DIR,\n num_train_epochs=3,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=16,\n learning_rate=2e-4,\n lr_scheduler_type=\"cosine\",\n warmup_steps=50,\n weight_decay=0.01,\n bf16=False,\n fp16=True,\n eval_strategy=\"epoch\",\n save_strategy=\"epoch\",\n load_best_model_at_end=True,\n metric_for_best_model=\"eval_loss\",\n logging_steps=10,\n seed=SEED,\n)\n\ndef formatting_func(examples):\n texts = []\n for msgs in examples[\"messages\"]:\n texts.append(tokenizer.apply_chat_template(msgs, tokenize=False))\n return texts\n\nsft_trainer = SFTTrainer(\n model=model,\n tokenizer=tokenizer,\n train_dataset=train_dataset,\n eval_dataset=val_dataset,\n args=training_args,\n formatting_func=formatting_func,\n)\n\nprint(f\"Training {len(train_dataset)} examples for {training_args.num_train_epochs} epochs...\")\nprint(f\"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}\")\nsft_trainer.train()" }, { "cell_type": "code", diff --git a/src/hippofloop/training/trainer.py b/src/hippofloop/training/trainer.py index 69a948e..4011d3e 100644 --- a/src/hippofloop/training/trainer.py +++ b/src/hippofloop/training/trainer.py @@ -76,7 +76,7 @@ def train( gradient_accumulation_steps=self._config.gradient_accumulation_steps, learning_rate=self._config.learning_rate, lr_scheduler_type=self._config.lr_scheduler, - warmup_ratio=self._config.warmup_ratio, + warmup_steps=50, weight_decay=self._config.weight_decay, bf16=self._config.bf16, fp16=self._config.fp16, @@ -88,12 +88,19 @@ def train( seed=self._config.seed, ) + def formatting_func(examples: dict) -> list[str]: + return [ + tokenizer.apply_chat_template(msgs, tokenize=False) + for msgs in examples["messages"] + ] + trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=train_dataset, eval_dataset=val_dataset, args=training_args, + formatting_func=formatting_func, ) logger.info("Starting training (%d epochs)", self._config.epochs)