-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_deep_model.py
More file actions
130 lines (110 loc) · 3.73 KB
/
eval_deep_model.py
File metadata and controls
130 lines (110 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import re
import os
from collections import Counter
import torch
from torch.utils.data import DataLoader
from utils.config import *
from utils.train_deep_model_utils import json_file
from utils.timeseries_dataset import read_files, create_splits
from utils.evaluator import Evaluator
def eval_deep_model(
    data_path,
    model_name,
    model_path=None,
    model_parameters_file=None,
    path_save=None,
    fnames=None,
    read_from_file=None,
    model=None,
):
    """Evaluate a trained deep model on time series data and print/save predictions.

    :param data_path: path to the windowed dataset (a directory of csvs or a
        single csv file). The path must contain the window size as a run of
        digits (e.g. ``.../TSB_128/``) — it is parsed back out of the path.
    :param model_name: key into the ``deep_models`` registry; also used to
        build the classifier name for the result columns.
    :param model_path: path to a saved ``state_dict``; required (together with
        *model_parameters_file*) when *model* is None.
    :param model_parameters_file: json file with the model's constructor
        parameters; required when *model* is None.
    :param path_save: directory to write ``<classifier>_preds.csv`` into;
        nothing is saved when None.
    :param fnames: optional list of specific time series to keep; mutually
        exclusive with *read_from_file*.
    :param read_from_file: path to a file describing train/val/test splits;
        mutually exclusive with *fnames*.
    :param model: an already-loaded model instance; when given, *model_path*
        and *model_parameters_file* are ignored.
    :raises ValueError: on inconsistent arguments, when the window size cannot
        be inferred from *data_path*, or when *fnames* contains series that
        are not present in *data_path*.
    """
    # The window size is encoded as the first run of digits in the data path.
    window_match = re.search(r'\d+', str(data_path))
    if window_match is None:
        raise ValueError(f"Could not infer the window size from '{data_path}'")
    window_size = int(window_match.group())

    # Validate argument combinations explicitly (asserts are stripped under -O).
    # Either a loaded model is given, or both pieces needed to load one.
    if model is None and (model_path is None or model_parameters_file is None):
        raise ValueError(
            "You should provide either a loaded model, or both the model path "
            "and the model parameters file"
        )
    if fnames is not None and read_from_file is not None:
        raise ValueError(
            "You should provide either the fnames or the path to the specific splits, not both"
        )

    # Load the model only if the caller did not provide one.
    if model is None:
        # Read the model's constructor parameters
        model_parameters = json_file(model_parameters_file)
        # Build the architecture and restore the trained weights
        model = deep_models[model_name](**model_parameters)
        model.load_state_dict(torch.load(model_path))
        model.eval()
        # Fall back to CPU so evaluation also works on machines without a GPU.
        model.to('cuda' if torch.cuda.is_available() else 'cpu')

    # Work out which time series to evaluate.
    if read_from_file is not None:
        # Prefer the test split; fall back to validation when test is empty.
        _, val_set, test_set = create_splits(
            data_path,
            read_from_file=read_from_file,
        )
        fnames = test_set if len(test_set) > 0 else val_set
    else:
        # Read data (single csv file or directory with csvs)
        if data_path.endswith('.csv'):
            # Split off the file name and keep the containing directory.
            data_path, csv_name = os.path.split(data_path)
            tmp_fnames = [csv_name]
        else:
            tmp_fnames = read_files(data_path)

        # Keep specific time series if fnames is given
        if fnames is not None:
            requested_count = len(fnames)
            fnames = [x for x in tmp_fnames if x in fnames]
            if len(fnames) != requested_count:
                raise ValueError("The data path does not include the time series in fnames")
        else:
            fnames = tmp_fnames

    # Evaluate the model on every selected time series.
    evaluator = Evaluator()
    classifier_name = f"{model_name}_{window_size}"
    results = evaluator.predict(
        model=model,
        fnames=fnames,
        data_path=data_path,
        deep_model=True,
    )
    results = results.sort_index()
    # Prefix every result column with the classifier name.
    results.columns = [f"{classifier_name}_{x}" for x in results.columns.values]

    # Print the predictions and the predicted-class distribution.
    print(results)
    counter = Counter(results[f"{classifier_name}_class"])
    print(dict(counter))

    # Save the results if a destination directory was given.
    if path_save is not None:
        file_name = os.path.join(path_save, f"{classifier_name}_preds.csv")
        results.to_csv(file_name)
if __name__ == "__main__":
    # Command-line entry point: evaluate one trained deep model on a dataset.
    # NOTE: the original description used a backslash line-continuation inside
    # the string literal, which embedded a raw newline and indentation into the
    # user-visible help text; implicit concatenation keeps it a single line.
    parser = argparse.ArgumentParser(
        prog='Evaluate deep learning models',
        description='Evaluate all deep learning architectures on a single or '
                    'multiple time series and save the results'
    )
    parser.add_argument('-d', '--data', type=str, help='path to the time series to predict', required=True)
    parser.add_argument('-m', '--model', type=str, help='model to run', required=True)
    parser.add_argument('-mp', '--model_path', type=str, help='path to the trained model', required=True)
    parser.add_argument('-pa', '--params', type=str, help="a json file with the model's parameters", required=True)
    parser.add_argument('-ps', '--path_save', type=str, help='path to save the results', default="results/raw_predictions")
    parser.add_argument('-f', '--file', type=str, help='path to file that contains a specific split', default=None)

    args = parser.parse_args()

    eval_deep_model(
        data_path=args.data,
        model_name=args.model,
        model_path=args.model_path,
        model_parameters_file=args.params,
        path_save=args.path_save,
        read_from_file=args.file,
    )