-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathMLP.cpp
More file actions
138 lines (113 loc) · 4.54 KB
/
MLP.cpp
File metadata and controls
138 lines (113 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include "MLP.h"
// Default constructor: creates an empty, untrained MLP.
// No ALGLIB network or trainer is configured here.
MLP::MLP()
{
}
// Constructs and immediately trains an MLP classifier via ALGLIB.
//
// Parameters:
//   h1pNo, h2pNo      - neuron counts for hidden layers 1 and 2; a value of 0
//                       means that layer is absent (topology chosen below).
//   qtty              - if k-fold validation: the number of folds (cast to int);
//                       otherwise: the learning-set share passed to
//                       initializeData(qtty, 100 - qtty) — presumably a
//                       percentage split; TODO confirm against ClusterizationMethods.
//   maxIter           - iteration limit used as the training stopping condition.
//   validationMethod  - true = k-fold cross-validation (fills `rep`),
//                       false = holdout train/test (fills `repp`).
//
// NOTE(review): h2No > 0 with h1No == 0 falls through all three topology
// branches and leaves `network` unconfigured (see the "h2No must be non zero"
// comment below) — confirm callers never pass that combination.
MLP::MLP(int h1pNo, int h2pNo, double qtty, int maxIter, bool validationMethod)
{
//no of neurons in hidden layers 1 and 2 (0 = layer not used)
h1No = h1pNo;
h2No = h2pNo;
// param = qtty;
// this->dL = dL; //learning
// this->dT = dT; //testing
// this->decay = wDecay;
this->maxIter = maxIter;
this->kFoldValidation = validationMethod;
alglib::mlpcreatetrainercls(X.getObjectAt(0).getFeatureCount(), X.getClassCount(), trn); //qtt of input features, number of classes to be produced
double wstep = 0.000;
mlpsetdecay(trn, 0.001); // by default we set moderate weight decay
mlpsetcond(trn, wstep, this->maxIter); // * we choose iterations limit as stopping condition (another condition - step size - is zero, which means that this condition is not active)
// Select the network topology from the hidden-layer counts; all variants use
// a classification output layer sized to the number of classes.
if ((h1No > 0) && (h2No > 0))
alglib::mlpcreatec2(X.getObjectAt(0).getFeatureCount(), h1No, h2No, X.getClassCount(), network); //create nn network with noofinput features, 2 hidden layers, noofclasses (and store to network variable)
if ((h1No > 0) && (h2No == 0))
alglib::mlpcreatec1(X.getObjectAt(0).getFeatureCount(), h1No, X.getClassCount(), network); //create nn network with no of input features, 1 hidden layer, noofclasses (and store to network variable)
if ((h1No == 0) && (h2No == 0))
alglib::mlpcreatec0(X.getObjectAt(0).getFeatureCount(), X.getClassCount(), network); //create nn network with no of input features, 0 hidden layers, noofclasses (and store to network variable)
///h2No must be non zero
if (this->kFoldValidation == true) //do kfold validation
{
ClusterizationMethods::initializeData();
alglib::mlpsetdataset(trn, ClusterizationMethods::learnSet, ClusterizationMethods::learnObjQtty); //attach learning data to data set
alglib::mlpkfoldcv(trn, network, 1, int(qtty), rep); // cross-validate with int(qtty) folds; results land in `rep`
}
else
{
ClusterizationMethods::initializeData(qtty, 100 - qtty);
alglib::mlpsetdataset(trn, ClusterizationMethods::learnSet, ClusterizationMethods::learnObjQtty); //attach learning data to data set
alglib::mlptrainnetwork(trn, network, 1, rep); // train network NRestarts=1, network is trained from random initial state. With NRestarts=0, network is trained without randomization (original state is used as initial point).
alglib::integer_1d_array Subset;
Subset.setlength(10);
// SubsetSize = -1: evaluate errors on the whole test set; results land in `repp`.
alglib::mlpallerrorssubset(network, testSet, testObjQtty, Subset, -1, repp);
}
//ctor
// now get network error // do not calculate cross-validation since it validates the topology of the network
}
// Destructor: nothing to release manually — the ALGLIB objects
// (network, trn, rep, repp) clean themselves up.
MLP::~MLP()
{
}
// Builds and returns the projection matrix Y.
//
// Each row of Y holds the object's original features followed by the
// network's per-class outputs (probability columns). Objects with a known
// class label (!= -1) keep it; unlabeled objects are assigned the class
// with the highest network output (arg-max).
//
// Improvements over the previous version: class count is computed once
// instead of on every loop iteration, the probability header vector
// reserves the real column count (was a no-op reserve(0)), and the
// input/output buffers are named by role instead of the misleading
// tmpYObj/tmpXObj pair.
ObjectMatrix MLP::getProjection()
{
    const int ftCount = X.getObjectAt(0).getFeatureCount();
    const int objCount = X.getObjectCount();
    const int classCount = X.getClassCount();
    initializeYMatrix(objCount, ftCount + classCount);
    alglib::real_1d_array netInput;   // feature vector fed to the network
    alglib::real_1d_array netOutput;  // per-class outputs produced by the network
    netInput.setlength(ftCount);
    netOutput.setlength(classCount);
    DataObject tmpO;
    for (int i = 0; i < objCount; i++)
    {
        tmpO = X.getObjectAt(i);
        // Copy the object's features both into the network input and into Y.
        for (int ft = 0; ft < ftCount; ft++)
        {
            double feature = tmpO.getFeatureAt(ft);
            netInput(ft) = feature;
            Y.updateDataObject(i, ft, feature);
        }
        alglib::mlpprocess(network, netInput, netOutput);
        // Append the class outputs to the row and track the arg-max class.
        double max_prob = netOutput(0);
        int indx = 0;
        for (int j = 0; j < classCount; j++)
        {
            Y.updateDataObject(i, j + ftCount, netOutput(j));
            if (max_prob < netOutput(j))
            {
                max_prob = netOutput(j);
                indx = j;
            }
        }
        // Keep a known label; otherwise use the predicted class.
        if (tmpO.getClassLabel() != -1)
            Y.updateDataObjectClass(i, tmpO.getClassLabel());
        else
            Y.updateDataObjectClass(i, indx);
    }
    // Header names for the appended probability columns.
    std::vector<std::string> probabilities;
    probabilities.reserve(classCount);
    for (int i = 0; i < classCount; i++)
        probabilities.push_back("probClass" + X.getStringClassAttributes().at(i));
    Y.addAtributes(probabilities);
    Y.setPrintClass(X.getStringClassAttributes());
    return Y;
}
// Returns the trained network's average relative error:
// from the cross-validation report (`rep`) when k-fold validation was used,
// otherwise from the holdout test-subset report (`repp`).
//
// Other fields available on the ALGLIB report, for reference:
//  * relclserror - fraction of misclassified cases
//  * avgce       - average cross-entropy
//  * rmserror    - root-mean-square error
//  * avgerror    - average error
//  * avgrelerror - average relative error (the one returned here)
//
// Fix: removed the unreachable `return rep.rmserror;` that followed the
// if/else (both branches already return) and a stray `//}` artifact.
double MLP::getStress()
{
    if (this->kFoldValidation)
        return rep.avgrelerror;
    return repp.avgrelerror;
}