-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregularizers.py
More file actions
89 lines (59 loc) · 1.89 KB
/
regularizers.py
File metadata and controls
89 lines (59 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from __future__ import absolute_import
import theano.tensor as T
class Regularizer(object):
    """Base regularizer that contributes no penalty.

    Subclasses override ``__call__`` to add a penalty term to the loss.
    ``set_param`` and ``set_layer`` attach the parameter or layer that a
    subclass computes its penalty from.
    """

    def set_param(self, p):
        """Attach ``p`` as the parameter this regularizer refers to."""
        self.p = p

    def set_layer(self, layer):
        """Attach ``layer`` as the layer this regularizer refers to."""
        self.layer = layer

    def __call__(self, loss):
        """Return ``loss`` unchanged; the base class adds no penalty."""
        unchanged = loss
        return unchanged

    def get_config(self):
        """Return a description holding only the concrete class name."""
        config = dict(name=type(self).__name__)
        return config
class WeightRegularizer(Regularizer):
    """Penalize a weight parameter with L1 and/or L2 terms.

    The penalty added to the loss is
    ``l1 * sum(|p|) + l2 * sum(p ** 2)``, where ``p`` is the parameter
    attached via ``set_param``.

    Args:
        l1: coefficient of the L1 (absolute-value) penalty.
        l2: coefficient of the L2 (squared) penalty.
    """

    def __init__(self, l1=0., l2=0.):
        self.l1 = l1
        self.l2 = l2

    def set_param(self, p):
        """Attach the weight ``p`` that the penalty is computed on."""
        self.p = p

    def __call__(self, loss):
        """Return ``loss`` plus the L1/L2 penalty on ``self.p``.

        ``set_param`` must have been called first.
        """
        # The coefficient scales the term (multiply, not add): with the
        # default l1=0. the L1 term has to vanish entirely.
        loss += self.l1 * T.sum(abs(self.p))
        loss += self.l2 * T.sum(self.p ** 2)
        return loss

    def get_config(self):
        """Return a serializable description of this regularizer."""
        return {
            "name": self.__class__.__name__,
            "l1": self.l1,
            "l2": self.l2,
        }
class ActivityRegularizer(Regularizer):
    """Penalize a layer's output with L1 and/or L2 terms.

    The penalty is computed on ``layer.get_output(True)``: the mean over
    axis 0 is taken, then summed, i.e.
    ``l1 * sum(mean(|out|, axis=0)) + l2 * sum(mean(out ** 2, axis=0))``.
    (Axis 0 is presumably the batch axis — confirm against the layer API.)

    Args:
        l1: coefficient of the L1 (absolute-value) penalty.
        l2: coefficient of the L2 (squared) penalty.
    """

    def __init__(self, l1=0., l2=0.):
        self.l1 = l1
        self.l2 = l2

    def set_layer(self, layer):
        """Attach the layer whose output the penalty is computed on."""
        self.layer = layer

    def __call__(self, loss):
        """Return ``loss`` plus the L1/L2 penalty on the layer's output.

        ``set_layer`` must have been called first.
        """
        # Build the output expression once and reuse it for both terms.
        output = self.layer.get_output(True)
        loss += self.l1 * T.sum(T.mean(abs(output), axis=0))
        # The L2 term squares the activations; using abs() here would just
        # duplicate the L1 term.
        loss += self.l2 * T.sum(T.mean(output ** 2, axis=0))
        return loss

    def get_config(self):
        """Return a serializable description of this regularizer."""
        return {
            "name": self.__class__.__name__,
            "l1": self.l1,
            "l2": self.l2,
        }
def l1(l=0.01):
    """Build a weight regularizer with L1 coefficient ``l`` and no L2 term."""
    regularizer = WeightRegularizer(l1=l)
    return regularizer
def l2(l=0.01):
    """Build a weight regularizer with L2 coefficient ``l`` and no L1 term."""
    regularizer = WeightRegularizer(l2=l)
    return regularizer
def l1l2(l1=0.01, l2=0.01):
    """Build a weight regularizer combining L1 and L2 penalties."""
    return WeightRegularizer(l2=l2, l1=l1)
def activity_l1(l=0.01):
    """Build an activity regularizer with L1 coefficient ``l`` only."""
    regularizer = ActivityRegularizer(l1=l)
    return regularizer
def activity_l2(l=0.01):
    """Build an activity regularizer with L2 coefficient ``l`` only."""
    return ActivityRegularizer(l2=l)
def activity_l1l2(l1=0.01, l2=0.01):
    """Build an activity regularizer combining L1 and L2 penalties."""
    return ActivityRegularizer(l2=l2, l1=l1)
# Name alias for the base class: Regularizer.__call__ returns the loss
# unchanged, so ``identity`` names a regularizer that adds no penalty.
identity = Regularizer
from .utils.generic_utils import get_from_module
def get(identifier, kwargs=None):
    """Resolve ``identifier`` to a regularizer via ``get_from_module``.

    The lookup runs against this module's globals under the
    'regularizer' category with ``instantiate=True``; ``kwargs`` is
    forwarded unchanged. What forms ``identifier`` may take is decided by
    ``get_from_module`` (not visible here).
    """
    namespace = globals()
    return get_from_module(
        identifier,
        namespace,
        'regularizer',
        instantiate=True,
        kwargs=kwargs,
    )