forked from deeplearning4j/deeplearning4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute_acc_vecs.java
More file actions
122 lines (120 loc) · 3.69 KB
/
compute_acc_vecs.java
File metadata and controls
122 lines (120 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/**
 * Scores the model on a word-analogy question file and prints accuracy statistics
 * to standard output.
 *
 * <p>The file is read line by line and split on whitespace. A line whose first token
 * is ":" starts a new section: its label is echoed and, for every section after the
 * first, the statistics gathered so far are printed. Every other line must contain
 * exactly four tokens "A B C D"; the question counts as correct when the vocabulary
 * word nearest to {@code A - B + C} (excluding A, B and C) equals D. All tokens are
 * upper-cased before they are looked up.
 * NOTE(review): this presumes the model vocabulary is stored upper-case - confirm.
 *
 * <p>Sections 1 to 5 are tallied as "semantic", later sections as "syntactic".
 *
 * @param vec          model used for vocabulary lookups and word vectors
 * @param testFileName path of the analogy question file
 */
private void computeAccuracyOnTest(Word2Vec vec, String testFileName) {
    // NOTE(review): the loop stops once more than this many questions have been scored.
    // This looks like a leftover debugging limit - confirm before trusting the numbers.
    final int scoredQuestionLimit = 3;
    // Sections numbered up to this value are tallied as semantic, the rest as syntactic.
    final int lastSemanticSection = 5;

    // try-with-resources closes the scanner (and the underlying file) on every exit path.
    try (Scanner s = new Scanner(new File(testFileName))) {
        int qid = 0, lineNumber = 0;  // qid: number of the current section
        int ccn = 0, tcn = 0;         // correct / scored questions in the current section
        int cacn = 0, tacn = 0;       // correct / scored questions overall
        int seac = 0, secn = 0;       // correct / scored semantic questions
        int syac = 0, sycn = 0;       // correct / scored syntactic questions
        int tq = 0, tqs = 0;          // well-formed questions / questions actually scored

        while (s.hasNextLine() && tqs <= scoredQuestionLimit) {
            lineNumber++;
            String line = s.nextLine();
            // A whitespace-only line splits into an empty array, hence the length checks.
            String[] toks = line.split("\\s");

            if (toks.length > 0 && toks[0].equals(":")) {
                // Start of a new section: print its label and the stats so far.
                System.out.println("new section:");
                if (line.length() > 1) {
                    System.out.println(line.substring(1));
                }
                if (qid > 0) {
                    printAccuracySoFar(ccn, tcn, cacn, tacn, seac, secn, syac, sycn);
                }
                qid++;
                // The top-1 counters are per section; without this reset they would
                // merely duplicate the overall counters.
                ccn = 0;
                tcn = 0;
                continue;
            }

            if (toks.length != 4) {
                System.err.println("possible error in file format? line:" + lineNumber);
                continue;
            }
            for (int i = 0; i < toks.length; i++) {
                toks[i] = toks[i].toUpperCase();
            }

            tq++;
            // Skip the whole question when any of its four words is out of vocabulary.
            if (!allWordsKnown(vec, toks)) {
                continue;
            }
            tqs++;

            // A - B + C. sub() is used instead of the in-place subi() so that the
            // vector returned by the model for A is left untouched.
            INDArray comp = vec.getWordVectorMatrix(toks[0])
                    .sub(vec.getWordVectorMatrix(toks[1]))
                    .add(vec.getWordVectorMatrix(toks[2]));
            String bestWord = nearestVocabWord(comp, toks);

            boolean correct = bestWord != null && bestWord.toUpperCase().equals(toks[3]);
            boolean semantic = qid <= lastSemanticSection;
            if (correct) {
                ccn++;
                cacn++;
                if (semantic) {
                    seac++;
                } else {
                    syac++;
                }
            }
            if (semantic) {
                secn++;
            } else {
                sycn++;
            }
            tcn++;
            tacn++;
        }

        // Statistics are otherwise printed only when the next section header is read,
        // so report the final section here.
        if (qid > 0) {
            printAccuracySoFar(ccn, tcn, cacn, tacn, seac, secn, syac, sycn);
        }
        System.out.printf("Questions seen / total: %d %d %.2f %%\n",
                tqs, tq, percentOf(tqs, tq));
    }
    catch (IOException ex) {
        ex.printStackTrace();
    }
}

/** Returns true when every token of the question is present in the model vocabulary. */
private boolean allWordsKnown(Word2Vec vec, String[] toks) {
    for (String tok : toks) {
        if (!vec.hasWord(tok)) {
            return false;
        }
    }
    return true;
}

/**
 * Scans the vocabulary cache for the word whose vector has the smallest
 * {@code distance2} to {@code target}, ignoring the first three question words.
 * NOTE(review): relies on cache.vectors() iterating in the same order that
 * cache.wordAtIndex() indexes - confirm against the cache implementation.
 *
 * @param target        the combined vector A - B + C
 * @param questionWords upper-cased question tokens; entries 0..2 are excluded
 * @return the nearest word, or null when no candidate was found
 */
private String nearestVocabWord(INDArray target, String[] questionWords) {
    double bestDist = Double.POSITIVE_INFINITY;
    String bestWord = null;
    Iterator<INDArray> vecs = cache.vectors();
    int index = 0;
    while (vecs.hasNext()) {
        // Advance the iterator and the index together on every pass, including for
        // skipped words; otherwise the loop never terminates on the first skipped word.
        INDArray curr = vecs.next();
        String currWord = cache.wordAtIndex(index);
        index++;
        if (currWord == null || isQuestionWord(currWord, questionWords)) {
            continue;
        }
        double currDist = target.distance2(curr);
        if (currDist < bestDist) {
            bestDist = currDist;
            bestWord = currWord;
        }
    }
    return bestWord;
}

/** Returns true when the candidate equals A, B or C of the question, ignoring case. */
private boolean isQuestionWord(String candidate, String[] questionWords) {
    String upper = candidate.toUpperCase();
    for (int i = 0; i < 3; i++) {
        if (upper.equals(questionWords[i])) {
            return true;
        }
    }
    return false;
}

/** Prints the per-section top-1 accuracy followed by the running overall accuracies. */
private void printAccuracySoFar(int ccn, int tcn, int cacn, int tacn,
                                int seac, int secn, int syac, int sycn) {
    System.out.printf("Accuracy top1: %.2f %% (%d / %d)\n",
            percentOf(ccn, tcn), ccn, tcn);
    System.out.printf("Total accuracy: %.2f %% "
            + "Semantic Accuracy: %.2f %% "
            + "Syntactic Accuracy: %.2f %%\n",
            percentOf(cacn, tacn),
            percentOf(seac, secn),
            percentOf(syac, sycn));
}

/** Returns part/whole as a percentage, or 0 when whole is 0 (instead of NaN). */
private double percentOf(int part, int whole) {
    return whole == 0 ? 0.0 : part / ((double) whole) * 100;
}