-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
54 lines (42 loc) · 1.76 KB
/
train.py
File metadata and controls
54 lines (42 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import set_seed
from config import EpsilonDPOConfig
from trainer import EpsilonDPOTrainer
def main(config):
    """Run ε-DPO training as described by ``config`` and save the result.

    Args:
        config: OmegaConf configuration providing three sections:

            * ``model`` -- keyword arguments forwarded to
              ``AutoModelForCausalLM.from_pretrained``; must contain
              ``pretrained_model_name_or_path``, which is also used to load
              the tokenizer.
            * ``dataset`` -- ``name`` (passed to ``load_dataset``),
              ``train_split`` and an optional ``eval_split``.
            * ``training_args`` -- keyword arguments forwarded to
              ``EpsilonDPOConfig``; ``seed`` and ``output_dir`` are also
              read directly here.
    """
    set_seed(config.training_args.seed)

    # Turn the OmegaConf nodes into plain Python containers before unpacking
    # them, so nested values reach the libraries as ordinary dicts/lists
    # instead of DictConfig/ListConfig objects (which are not JSON
    # serializable and do not pass ``isinstance(..., (list, dict))`` checks).
    model_kwargs = OmegaConf.to_container(config.model, resolve=True)
    training_kwargs = OmegaConf.to_container(config.training_args, resolve=True)

    # The trained model and the reference model are two separate instances
    # loaded with the same arguments, both in bfloat16.
    model = AutoModelForCausalLM.from_pretrained(**model_kwargs, torch_dtype=torch.bfloat16)
    ref_model = AutoModelForCausalLM.from_pretrained(**model_kwargs, torch_dtype=torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(config.model.pretrained_model_name_or_path)
    # The EOS token is reused as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Only the preference-pair columns are kept.
    columns = ['chosen', 'rejected']
    dataset = load_dataset(config.dataset.name)
    train_dataset = dataset[config.dataset.train_split].select_columns(columns)

    # ``eval_split`` is optional: ``.get`` yields None when the key is absent
    # rather than raising, and an empty/None value also disables evaluation.
    eval_split = config.dataset.get('eval_split')
    if eval_split:
        eval_dataset = dataset[eval_split].select_columns(columns)
    else:
        eval_dataset = None

    training_args = EpsilonDPOConfig(**training_kwargs)
    trainer = EpsilonDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    trainer.train()
    trainer.save_model(config.training_args.output_dir)
if __name__ == "__main__":
    # Script entry point: the single required option is the path to a
    # configuration file, which is loaded with OmegaConf and passed to main().
    cli = argparse.ArgumentParser(description='ε-DPO configuration')
    cli.add_argument('--config', required=True, type=str, help='path to configuration')
    cli_args = cli.parse_args()
    main(OmegaConf.load(cli_args.config))