-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidation.js
More file actions
42 lines (33 loc) · 1.17 KB
/
validation.js
File metadata and controls
42 lines (33 loc) · 1.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
const {mean} = require('./utils');
class CrossValidator {
/**
* @param {!Classifier} model
* @param {!Array<!Array<*>>} data
* @param {!Array<*>} labels
* @param {!Number} k
*/
constructor(model, data, labels, k = 10) {
this.k = k;
this.data = data;
this.labels = labels;
this.model = model;
}
score() {
const step = Math.floor(this.data.length / this.k);
const scores = [];
for (let i = 0; i < this.k; i++) {
const testingData = this.data.slice(i * step, (i + 1) * step);
const testingLabels = this.labels.slice(i * step, (i + 1) * step);
const trainingData = this.data.slice(0, i * step);
this.data.slice((i + 1) * step, this.data.length).forEach(d => trainingData.push(d));
const trainingLabels = this.labels.slice(0, i * step);
this.labels.slice((i + 1) * step, this.data.length).forEach(l => trainingLabels.push(l));
const model = new this.model(trainingData, trainingLabels);
scores.push(
testingData.filter((d, idx) => model.predict(d) === testingLabels[idx]).length
/ testingData.length);
}
return mean(scores);
}
}
module.exports = {CrossValidator};