# linear_trainer.py (forked from ntumlgroup/LibMultiLabel)
import logging
from math import ceil

import numpy as np
from tqdm import tqdm

import libmultilabel.linear as linear
from libmultilabel.common_utils import dump_log, is_multiclass_dataset
from libmultilabel.linear.tree import EnsembleTreeModel, TreeModel, train_ensemble_tree
from libmultilabel.linear.utils import LINEAR_TECHNIQUES


def linear_test(config, model, datasets, label_mapping):
    """Evaluate a trained linear model on the test split in batches.

    Returns the computed metrics together with the predicted labels and
    decision scores collected according to `save_k_predictions` or
    `save_positive_predictions`.
    """
    metrics = linear.get_metrics(config.monitor_metrics, datasets["test"]["y"].shape[1], multiclass=model.multiclass)
    num_instance = datasets["test"]["x"].shape[0]
    k = config.save_k_predictions
    if k > 0:
        # Preallocate fixed-size arrays for the top-k labels and scores.
        labels = np.zeros((num_instance, k), dtype=object)
        scores = np.zeros((num_instance, k), dtype="d")
    else:
        labels = []
        scores = []

    predict_kwargs = {}
    if isinstance(model, (TreeModel, EnsembleTreeModel)):
        # Tree-based models additionally take a beam width at prediction time.
        predict_kwargs["beam_width"] = config.beam_width

    # Predict in batches to bound memory usage.
    for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
        batch_slice = np.s_[i * config.eval_batch_size : (i + 1) * config.eval_batch_size]
        preds = model.predict_values(datasets["test"]["x"][batch_slice], **predict_kwargs)
        target = datasets["test"]["y"][batch_slice].toarray()
        metrics.update(preds, target)
        if k > 0:
            labels[batch_slice], scores[batch_slice] = linear.get_topk_labels(preds, label_mapping, k)
        elif config.save_positive_predictions:
            batch_labels, batch_scores = linear.get_positive_labels(preds, label_mapping)
            labels.append(batch_labels)
            scores.append(batch_scores)
    metric_dict = metrics.compute()
    return metric_dict, labels, scores
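

# Illustrative usage sketch, not part of the original file. It lists the config
# fields `linear_test` reads, shown as a hand-built namespace; the values below
# are assumptions, and in practice the config comes from LibMultiLabel's parsed
# CLI/YAML arguments.
#
#   from types import SimpleNamespace
#   test_config = SimpleNamespace(
#       monitor_metrics=["Macro-F1", "P@1", "P@5"],
#       save_k_predictions=5,           # collect the top-5 labels per instance
#       save_positive_predictions=False,
#       eval_batch_size=256,
#       beam_width=10,                  # only read for tree-based models
#   )
#   metric_dict, labels, scores = linear_test(test_config, model, datasets, label_mapping)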


def linear_train(datasets, config):
    """Train a linear model with the technique given by `config.linear_technique`."""
    # Detect the task type; tree models do not support multiclass data.
    multiclass = is_multiclass_dataset(datasets["train"], "y")

    if config.linear_technique == "tree":
        if multiclass:
            raise ValueError("Tree model should only be used with multilabel datasets.")

        if config.tree_ensemble_models > 1:
            # Train an ensemble of trees, one per random seed.
            model = train_ensemble_tree(
                datasets["train"]["y"],
                datasets["train"]["x"],
                options=config.liblinear_options,
                K=config.tree_degree,
                dmax=config.tree_max_depth,
                n_trees=config.tree_ensemble_models,
                seed=config.seed,
            )
        else:
            model = LINEAR_TECHNIQUES[config.linear_technique](
                datasets["train"]["y"],
                datasets["train"]["x"],
                options=config.liblinear_options,
                K=config.tree_degree,
                dmax=config.tree_max_depth,
            )
    else:
        model = LINEAR_TECHNIQUES[config.linear_technique](
            datasets["train"]["y"],
            datasets["train"]["x"],
            multiclass=multiclass,
            options=config.liblinear_options,
        )
    return model
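

# Illustrative usage sketch, not part of the original file: training a plain
# one-vs-rest model on a random sparse dataset. The "1vsrest" technique key and
# the LIBLINEAR options are assumptions about the configured technique.
#
#   from types import SimpleNamespace
#   import scipy.sparse as sparse
#
#   x = sparse.random(100, 50, density=0.1, format="csr")
#   y = sparse.random(100, 10, density=0.2, format="csr")
#   y.data[:] = 1  # binarize the label matrix
#   train_config = SimpleNamespace(
#       linear_technique="1vsrest",
#       liblinear_options="-s 2 -B 1",
#   )
#   model = linear_train({"train": {"x": x, "y": y}}, train_config)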


def linear_run(config):
    """Entry point for linear models: train or load a pipeline, then test."""
    if config.seed is not None:
        np.random.seed(config.seed)

    if config.eval:
        # Evaluation only: load a previously saved preprocessor and model.
        preprocessor, model = linear.load_pipeline(config.checkpoint_path)
        datasets = linear.load_dataset(config.data_format, config.training_file, config.test_file)
        datasets = preprocessor.transform(datasets)
    else:
        preprocessor = linear.Preprocessor(config.include_test_labels, config.remove_no_label_data)
        datasets = linear.load_dataset(
            config.data_format,
            config.training_file,
            config.test_file,
            config.label_file,
        )
        datasets = preprocessor.fit_transform(datasets)
        model = linear_train(datasets, config)
        linear.save_pipeline(config.checkpoint_dir, preprocessor, model)

    if config.test_file is not None:
        assert not (
            config.save_positive_predictions and config.save_k_predictions > 0
        ), """
If save_k_predictions is larger than 0, only the top k labels are saved.
To save all labels with decision values larger than 0, use save_positive_predictions with save_k_predictions=0."""
        metric_dict, labels, scores = linear_test(config, model, datasets, preprocessor.label_mapping)

        dump_log(config=config, metrics=metric_dict, split="test", log_path=config.log_path)
        print(linear.tabulate_metrics(metric_dict, "test"))

        if config.save_k_predictions > 0:
            # One line per test instance: its top-k "label:score" pairs.
            with open(config.predict_out_path, "w") as fp:
                for label, score in zip(labels, scores):
                    out_str = " ".join([f"{i}:{s:.4}" for i, s in zip(label, score)])
                    fp.write(out_str + "\n")
            logging.info(f"Saved predictions to: {config.predict_out_path}")
        elif config.save_positive_predictions:
            # Positive labels were collected per batch, so flatten the batches.
            with open(config.predict_out_path, "w") as fp:
                for batch_labels, batch_scores in zip(labels, scores):
                    for label, score in zip(batch_labels, batch_scores):
                        out_str = " ".join([f"{i}:{s:.4}" for i, s in zip(label, score)])
                        fp.write(out_str + "\n")
            logging.info(f"Saved predictions to: {config.predict_out_path}")
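

# Illustrative end-to-end sketch, not part of the original file. The namespace
# below lists the fields this entry point and its helpers read for a plain
# train-and-test run; all paths and values are placeholder assumptions.
#
#   from types import SimpleNamespace
#   run_config = SimpleNamespace(
#       eval=False, seed=42,
#       data_format="txt", training_file="data/train.txt",
#       test_file="data/test.txt", label_file=None,
#       include_test_labels=False, remove_no_label_data=False,
#       linear_technique="1vsrest", liblinear_options="-s 2 -B 1",
#       monitor_metrics=["Macro-F1", "P@1"], eval_batch_size=256,
#       save_k_predictions=0, save_positive_predictions=False,
#       checkpoint_dir="runs/linear", log_path="runs/linear/logs.json",
#   )
#   linear_run(run_config)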