-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
103 lines (74 loc) · 2.81 KB
/
utils.py
File metadata and controls
103 lines (74 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import matplotlib.pyplot as plt
import sklearn as sk
import pandas as pd
from sklearn.metrics import confusion_matrix
def plot_figure(training, validation, start, stop, num_point, name=None, ylabel=None, legend=None,
                xlabel='regularization'):
    """Plot a training curve and a validation curve against an evenly spaced x-axis.

    The x values are ``num_point`` samples from ``start`` to ``stop``
    (inclusive), generated with ``np.linspace``.

    Args:
        training: y-values of the training curve; its length should match ``num_point``.
        validation: y-values of the validation curve; its length should match ``num_point``.
        start: first x-axis value.
        stop: last x-axis value.
        num_point: number of x-axis samples.
        name: path passed to ``plt.savefig``. If None, nothing is saved and the
            figure is left open so the caller can show or save it.
        ylabel: y-axis label text (None leaves the label empty).
        legend: legend labels in plotting order (training first). If None,
            no legend is drawn (``plt.legend(None)`` would raise TypeError).
        xlabel: x-axis label text; defaults to the previously hard-coded
            'regularization'.
    """
    x = np.linspace(start=start, stop=stop, num=num_point)
    fig = plt.figure()
    plt.plot(x, training)
    plt.plot(x, validation)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if legend is not None:
        plt.legend(legend)
    if name is None:
        # Nowhere to save: keep the figure open for the caller.
        return
    try:
        plt.savefig(name)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def calculate_profit(val_gt_profit, val_cls_pred, val_cls_gt):
    """Sum the ground-truth profit over the samples predicted as class 1.

    Args:
        val_gt_profit: per-sample profit values (array-like, same length as
            ``val_cls_pred``).
        val_cls_pred: per-sample classification predictions; entries equal
            to 1 select the sample.
        val_cls_gt: ground-truth labels. Not used by the computation; kept
            for call-site compatibility.

    Returns:
        The sum of ``val_gt_profit`` at positions where ``val_cls_pred == 1``
        (0 when no sample is selected).
    """
    # asarray so that plain lists work: ``[...] == 1`` on a list is just False,
    # which previously selected nothing and silently returned 0.
    profit = np.asarray(val_gt_profit)
    selected = np.asarray(val_cls_pred) == 1
    return np.sum(profit[selected])
def plot_figure2(train_data, val_data, start, stop, num_point, legends, xlabels, ylabels, save_path):
    """Plot training and validation curves against an evenly spaced x-axis and save them.

    Args:
        train_data: y-values of the training curve (one per x sample).
        val_data: y-values of the validation curve (one per x sample).
        start: first x-axis value.
        stop: last x-axis value (inclusive, via ``np.linspace``).
        num_point: number of x-axis samples.
        legends: legend labels in plotting order (training first).
        xlabels: x-axis label text.
        ylabels: y-axis label text.
        save_path: destination passed to ``plt.savefig``.
    """
    x = np.linspace(start=start, stop=stop, num=num_point)
    fig = plt.figure()
    plt.plot(x, train_data)
    plt.plot(x, val_data)
    plt.legend(legends)
    plt.xlabel(xlabels)
    plt.ylabel(ylabels)
    try:
        plt.savefig(save_path)
    finally:
        # Close the figure even if saving fails; otherwise every call leaks an
        # open pyplot figure (pyplot warns once too many are alive).
        plt.close(fig)
def pure_profit(cls_pred, cls_gt, profit_gt, cost=30):
    """Sum ``profit_gt - cost`` over the samples predicted as class 1.

    Args:
        cls_pred: per-sample classification predictions; entries equal to 1
            select the sample.
        cls_gt: ground-truth labels. Not used by the computation; kept for
            call-site compatibility.
        profit_gt: per-sample profit values (array-like, same length as
            ``cls_pred``).
        cost: constant subtracted from every ``profit_gt`` entry before
            summing. Defaults to the previously hard-coded 30 (presumably a
            per-sample marketing cost — confirm with callers).

    Returns:
        The sum of ``profit_gt - cost`` at positions where ``cls_pred == 1``
        (0 when no sample is selected).
    """
    # asarray lets plain lists work (``list - 30`` would raise TypeError, and
    # ``list == 1`` is a single False that would select nothing).
    net_profit = np.asarray(profit_gt) - cost
    selected = np.asarray(cls_pred) == 1
    return np.sum(net_profit[selected])
def classification_fusion(cls_pred, reg_pred, thres_high, thres_low):
    """Override classification predictions using confident regression outputs.

    For every index i:
      * if ``reg_pred[i] > thres_high``, ``cls_pred[i]`` is set to 1;
      * if ``reg_pred[i] < thres_low``, ``cls_pred[i]`` is set to 0.
    The low-threshold override is applied last, so it wins if both hold
    (only possible when ``thres_low > thres_high``).

    Args:
        cls_pred: classification predictions; must support boolean-mask
            assignment (e.g. a numpy array). It is MODIFIED IN PLACE.
        reg_pred: regression outputs aligned with ``cls_pred`` (array-like).
        thres_high: strict upper threshold forcing a prediction of 1.
        thres_low: strict lower threshold forcing a prediction of 0.

    Returns:
        ``cls_pred`` itself, after modification.
    """
    reg = np.asarray(reg_pred)
    # Build both masks before writing so they depend only on reg_pred.
    force_one = reg > thres_high
    force_zero = reg < thres_low
    cls_pred[force_one] = 1
    cls_pred[force_zero] = 0
    return cls_pred
def recall_cls(cls_pred, cls_gt):
    """Compute overall accuracy and positive-class recall for binary labels.

    :param cls_pred: (N,) predicted labels; 1 denotes the positive class.
    :param cls_gt: (N,) ground-truth labels; 1 denotes the positive class.
    :return: ``(total_recall, rec_recall)`` where ``total_recall`` is the
        fraction of positions with ``cls_pred == cls_gt`` (i.e. accuracy,
        as float32), and ``rec_recall`` is TP / (number of ground-truth
        positives), or nan when there are no ground-truth positives.
    """
    # asarray so lists compare element-wise instead of as whole objects.
    pred = np.asarray(cls_pred)
    gt = np.asarray(cls_gt)
    total_recall = np.array(pred == gt, dtype=np.float32).mean()

    gt_positive = gt == 1
    n_positive = np.count_nonzero(gt_positive)
    if n_positive == 0:
        # Recall is undefined without positives; avoid a 0/0 RuntimeWarning.
        rec_recall = float('nan')
    else:
        # Count true positives explicitly rather than summing gt * pred,
        # which is only correct when predictions are strictly 0/1.
        rec_recall = np.count_nonzero(gt_positive & (pred == 1)) / n_positive
    return total_recall, rec_recall
def confusion_matrix_compute(cls_pred, cls_gt):
    """Compute overall, positive-class and negative-class accuracy for binary labels.

    Labels are treated as binary with 1 as the positive class and every
    other value as negative (matching the ``== 1`` convention used elsewhere
    in this module).

    :param cls_pred: (N,) predicted labels.
    :param cls_gt: (N,) ground-truth labels.
    :return: ``(total_acc, pos_acc, neg_acc)``:
        ``total_acc`` — fraction of positions where prediction equals truth (float32);
        ``pos_acc`` — TP / (TP + FN), nan if there are no ground-truth positives;
        ``neg_acc`` — TN / (TN + FP), nan if there are no ground-truth negatives.
    """
    pred = np.asarray(cls_pred)
    gt = np.asarray(cls_gt)
    total_acc = np.array(pred == gt, dtype=np.float32).mean()

    # Count directly instead of unpacking sklearn's confusion_matrix(...).ravel():
    # that matrix is 1x1 when only one class is present, so the 4-way unpack
    # raised ValueError on e.g. an all-positive batch.
    gt_pos = gt == 1
    pred_pos = pred == 1
    n_pos = np.count_nonzero(gt_pos)
    n_neg = np.count_nonzero(~gt_pos)
    tp = np.count_nonzero(gt_pos & pred_pos)
    tn = np.count_nonzero(~gt_pos & ~pred_pos)

    pos_acc = tp / n_pos if n_pos else float('nan')
    neg_acc = tn / n_neg if n_neg else float('nan')
    return total_acc, pos_acc, neg_acc