forked from Multihuntr/gff
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
65 lines (50 loc) · 1.59 KB
/
train.py
File metadata and controls
65 lines (50 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import datetime
import sys
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import yaml
import gff.constants
import gff.dataloaders
import gff.evaluation
import gff.models.creation
import gff.training
import gff.util
def parse_args(argv):
    """Parse the command-line arguments of the training script.

    Args:
        argv: Argument strings to parse, without the program name
            (e.g. ``sys.argv[1:]``).

    Returns:
        An ``argparse.Namespace`` with two attributes:
        ``config_path`` (a ``Path``) and ``overwrite`` (a list of values,
        each converted by ``gff.util.pair``; empty when the flag is absent).
    """
    # Pass the text as ``description``. The first positional parameter of
    # ArgumentParser is ``prog``, so passing it positionally would make the
    # usage line read "usage: Train a model ..." and leave the help without
    # any description.
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("config_path", type=Path)
    parser.add_argument(
        "--overwrite",
        "-o",
        type=gff.util.pair,
        nargs="*",
        default=[],
        help="Overwrite config setting",
    )
    return parser.parse_args(argv)
def main(args):
    """Train a model from a YAML config, then run inference on the test set.

    Args:
        args: Parsed arguments providing ``config_path`` (path of the YAML
            config file) and ``overwrite`` (an iterable of key/value pairs
            that replace top-level entries of the loaded config).

    Side effects:
        Creates a folder ``runs/<name>_<ISO timestamp>`` and writes the
        effective configuration to ``config.yml`` inside it. The same folder
        is handed to the training loop and to the inference step.
    """
    # Read the configuration, then let command-line pairs take precedence.
    with open(args.config_path) as config_file:
        C = yaml.safe_load(config_file)
    for key, value in args.overwrite:
        C[key] = value

    # Seed everything up front so runs are reproducible.
    seed = C["seed"]
    gff.util.seed_packages(seed)

    # Build the model and move it to the configured device.
    model = gff.models.creation.create(C)
    model.to(C["device"])

    # The data loaders get their own explicitly seeded generator.
    data_rng = torch.Generator().manual_seed(seed)
    loaders = gff.dataloaders.create_all(C, data_rng)
    train_dl, val_dl, test_dl = loaders

    # Create the run folder and keep a copy of the effective config in it.
    model_folder = Path("runs") / f'{C["name"]}_{datetime.datetime.now().isoformat()}'
    model_folder.mkdir(parents=True, exist_ok=True)
    with (model_folder / "config.yml").open("w") as config_copy:
        yaml.safe_dump(C, config_copy)

    # Train using the training and validation loaders.
    gff.training.training_loop(C, model_folder, model, (train_dl, val_dl))

    # Finally, run inference on the test set with the model.
    gff.evaluation.model_inference(model_folder, model, test_dl)
# Script entry point: parse the command line (program name excluded) and run.
if __name__ == "__main__":
    cli_args = parse_args(sys.argv[1:])
    main(cli_args)