-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy patheval.py
More file actions
87 lines (56 loc) · 2.01 KB
/
eval.py
File metadata and controls
87 lines (56 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from os.path import join
import pickle
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
import argparse
def load_data(file_path):
    """Load the three objects pickled back-to-back in ``file_path``.

    The file must hold three consecutive pickles, which are read in order
    as ``word``, ``stem`` and ``bin``. Callers in this script use ``bin`` as
    the gold target labels and ``len(word)`` as the number of items to
    predict.

    Args:
        file_path: path to the pickle file.

    Returns:
        A ``(word, stem, bin)`` tuple in the order the objects were stored.

    Raises:
        OSError: if the file cannot be opened.
        EOFError: if the file contains fewer than three pickles.

    Note:
        ``pickle.load`` can execute arbitrary code. Only point this at
        trusted, locally produced files.
    """
    # ``with`` guarantees the handle is closed even if a load raises
    # (e.g. EOFError on a truncated file); the original leaked it then.
    with open(file_path, 'rb') as f:
        word = pickle.load(f)
        stem = pickle.load(f)
        bin_labels = pickle.load(f)  # renamed: ``bin`` shadows the builtin
    return word, stem, bin_labels
def predict(output, length):
    """Map the first ``length`` score pairs in ``output`` to 0/1 labels.

    For every index ``k`` in ``range(length)``, ``output[k][0]`` is read as a
    pair of scores. The label is 0 when the first score is strictly greater
    than the second, and 1 otherwise, so ties go to 1. Presumably the pair
    holds per-class scores for classes 0 and 1; confirm against the model
    that writes these pickles.

    Args:
        output: indexable model output; ``output[k][0][0]`` and
            ``output[k][0][1]`` must be comparable.
        length: number of leading entries of ``output`` to label.

    Returns:
        A list of ``length`` ints, each 0 or 1.

    Raises:
        IndexError: if ``output`` has fewer than ``length`` entries.
    """
    def _label(entry):
        scores = entry[0]
        return 0 if scores[0] > scores[1] else 1

    return [_label(output[k]) for k in range(length)]
def evaluate(output_path, printresult=True, source_data='data/parsedata/'):
    """Score the pickled model outputs in ``output_path`` against gold labels.

    A file in ``output_path`` is evaluated only if its name splits on ``.``
    into exactly three parts and the last part is ``data``. For each such
    file:

    1. the pickled model output is loaded and turned into 0/1 predictions
       with ``predict``;
    2. the gold labels are the third object returned by ``load_data`` for the
       file of the same name under ``source_data``;
    3. predictions and targets from all files are pooled.

    Scoring happens once, on the pooled lists.

    Args:
        output_path: directory holding the pickled outputs to score.
        printresult: when true, print the per-class F-scores (index 2 of
            sklearn's ``precision_recall_fscore_support`` result) followed
            by the accuracy.
        source_data: directory holding the gold pickles read by
            ``load_data``. It defaults to the previously hard-coded location,
            so existing callers behave the same.

    Returns:
        A tuple ``(precision_recall_fscore_support(...), accuracy_score(...))``
        computed over all matching files together.

    Raises:
        FileNotFoundError: if ``output_path`` does not exist (from
            ``os.listdir``) or a matching gold file is missing.
    """
    total_target = []
    total_prediction = []
    for fname in os.listdir(output_path):
        names = fname.split('.')
        # Only "<a>.<b>.data" files are model outputs; skip everything else.
        if len(names) != 3 or names[2] != 'data':
            continue
        # NOTE: pickle.load runs arbitrary code -- only evaluate trusted runs.
        # ``with`` closes the handle; the original never closed it.
        with open(os.path.join(output_path, fname), 'rb') as foutput:
            output = pickle.load(foutput)
        word, _stem, gold = load_data(os.path.join(source_data, fname))
        prediction = predict(output, len(word))
        # extend() appends in place; "total = total + x" re-copied the
        # accumulated list for every file (quadratic overall).
        total_target.extend(gold)
        total_prediction.extend(prediction)
    pre_recall_fscore = precision_recall_fscore_support(total_target, total_prediction)
    acc = accuracy_score(total_target, total_prediction)
    if printresult:
        print(pre_recall_fscore[2], acc)
    return (pre_recall_fscore, acc)
# Command-line entry point. The guard keeps ``import`` of this module from
# parsing sys.argv and running an evaluation as a side effect.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--id', help='Experiment id', type=str, required=True)
    parser.add_argument('-e', '--epoch', help='Epoch', type=str)
    arg = parser.parse_args()
    if arg.epoch is None:
        # No epoch given: sweep the per-epoch directories <id>/0 .. <id>/99.
        for epoch in range(100):
            epoch_path = join(arg.id, str(epoch))
            if not os.path.isdir(epoch_path):
                # Runs with fewer (or non-contiguous) saved epochs previously
                # aborted here with FileNotFoundError from os.listdir.
                continue
            print(epoch_path)
            evaluate(epoch_path)
    else:
        evaluate(join(arg.id, arg.epoch))