-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsochastic_gradient_descent.py
More file actions
63 lines (49 loc) · 1.85 KB
/
sochastic_gradient_descent.py
File metadata and controls
63 lines (49 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
from matplotlib import pyplot as plt
class LinearClassifier:
    """Linear soft-margin SVM trained with stochastic (sub)gradient descent.

    The model is a single weight vector ``w`` used through the linear score
    ``w . x``; training pushes ``y * (w . x)`` to be at least 1 for every
    sample. No separate intercept is learned, so a bias must be supplied as a
    constant feature column by the caller.
    """

    def __isMisclasified(self, x, y, w):
        """Return True when sample ``(x, y)`` violates the margin under ``w``.

        This is the hinge-loss condition ``y * (w . x) < 1``. It is True for
        points on the wrong side of the boundary AND for correctly classified
        points that still lie inside the margin.

        x: data point to be processed (vector)
        y: the correct data label (expected to be +1 or -1)
        w: the weight vector
        Returns: True if ``y * (w^T x) < 1``, False otherwise.
        """
        return y * np.dot(w, x) < 1

    def svm_sgd_plot(self, features, labels, *, eta=1, epochs=100000, plot=True):
        """Learn the weight vector with SGD on the L2-regularized hinge loss.

        Every sample triggers one update, with learning rate ``eta`` and a
        regularization strength ``lambda = 1 / epoch`` that decays over time:

        * margin violated (``y * w.x < 1``): ``w <- w + eta * (y*x - 2*lambda*w)``
        * otherwise:                         ``w <- w + eta * (-2*lambda*w)``

        features: the data set, one row per sample (2-D array-like)
        labels: the class (+1 / -1) assigned to each row of ``features``
        eta: learning rate (keyword-only, default 1)
        epochs: number of full passes over the data (keyword-only, default 100000)
        plot: if True (default), show a matplotlib chart that marks every epoch
            in which at least one margin violation occurred (blocks in
            ``plt.show()``)
        Returns: the learned weight vector, a 1-D float ndarray with one entry
            per feature column.
        Raises: IndexError if ``features`` is empty, or (when ``epochs >= 1``)
            if ``labels`` has fewer entries than ``features`` has rows.
        """
        # Coerce to float arrays: with plain Python lists, ``label * x`` would
        # be list repetition (e.g. ``-1 * [a, b] == []``) and corrupt ``w``.
        features = np.asarray(features, dtype=float)
        labels = np.asarray(labels, dtype=float)

        # Random initial weights, one per feature column.
        w = np.random.rand(len(features[0]))
        # errors[i] is 1 if epoch i + 1 had any margin violation, else 0.
        errors = []
        for epoch in range(1, epochs + 1):
            reg = 2 * (1 / epoch)  # 2 * lambda, with lambda = 1 / epoch
            error = 0
            for ind, x in enumerate(features):
                y = labels[ind]
                if self.__isMisclasified(x, y, w):
                    # Hinge-loss subgradient plus regularization.
                    w = w + eta * ((y * x) + (-reg * w))
                    error = 1
                else:
                    # Margin satisfied: only the regularizer shrinks w.
                    w = w + eta * (-reg * w)
            errors.append(error)

        if plot:
            # Draw on one explicit Axes: a bare ``plt.axes()`` call would add a
            # second, empty Axes on top of the plotted data.
            _, ax = plt.subplots()
            ax.plot(range(1, epochs + 1), errors, "|")
            ax.set_ylim(0.5, 1.5)  # only the "1" (error) marks are visible
            ax.set_yticklabels([])
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Misclassified")
            plt.show()
        return w