-- model_utils.lua
-- Forked from harvardnlp/seq2seq-attn.
-- Clone a network T times (e.g. once per timestep when unrolling an RNN).
-- The clones share parameter and gradient storage with the original net,
-- so a single optimizer update applies to every clone.
function clone_many_times(net, T)
  local clones = {}

  local params, gradParams
  if net.parameters then
    params, gradParams = net:parameters()
    if params == nil then
      params = {}
    end
  end

  local paramsNoGrad
  if net.parametersNoGrad then
    paramsNoGrad = net:parametersNoGrad()
  end

  -- Serialize the network once; every clone is deserialized from this buffer.
  local mem = torch.MemoryFile("w"):binary()
  mem:writeObject(net)

  for t = 1, T do
    -- We need a fresh reader for each clone so we don't reuse pointers to
    -- already-read objects.
    local reader = torch.MemoryFile(mem:storage(), "r"):binary()
    local clone = reader:readObject()
    reader:close()

    if net.parameters then
      local cloneParams, cloneGradParams = clone:parameters()
      local cloneParamsNoGrad
      -- Point the clone's tensors at the original net's storage so that all
      -- clones share weights and gradients.
      for i = 1, #params do
        cloneParams[i]:set(params[i])
        cloneGradParams[i]:set(gradParams[i])
      end
      if paramsNoGrad then
        cloneParamsNoGrad = clone:parametersNoGrad()
        for i = 1, #paramsNoGrad do
          cloneParamsNoGrad[i]:set(paramsNoGrad[i])
        end
      end
    end

    clones[t] = clone
    collectgarbage()
  end

  mem:close()
  return clones
end
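
-- A minimal usage sketch (illustration only, not part of the original file):
-- cloning a unit for an unrolled RNN. The module and sizes below
-- (nn.Linear(10, 10), T = 5) are made-up examples.
--[[
require 'nn'

local unit = nn.Linear(10, 10)
local clones = clone_many_times(unit, 5)

-- The clones share storage: writing through one is visible in all.
local params1 = clones[1]:parameters()
params1[1]:fill(0.5)
local params3 = clones[3]:parameters()
assert(params3[1][1][1] == 0.5)
--]]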

-- One Adagrad update step: accumulate squared gradients in state.var and
-- scale the learning rate eta by the per-parameter root of that sum.
function adagradStep(x, dfdx, eta, state)
  if not state.var then
    -- Lazily allocate the accumulators on the first call.
    state.var = torch.Tensor():typeAs(x):resizeAs(x):zero()
    state.std = torch.Tensor():typeAs(x):resizeAs(x)
  end
  state.var:addcmul(1, dfdx, dfdx)             -- var += dfdx .* dfdx
  state.std:sqrt(state.var)                    -- std = sqrt(var), elementwise
  x:addcdiv(-eta, dfdx, state.std:add(1e-10))  -- x -= eta * dfdx / (std + eps)
end
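
-- A minimal usage sketch (illustration only, not part of the original file):
-- minimizing the toy loss f(x) = 0.5 * ||x||^2, whose gradient is x itself.
-- The dimension (4), step size (0.1), and iteration count are made up.
--[[
local x = torch.randn(4)
local state = {}
for i = 1, 200 do
  local dfdx = x:clone()  -- gradient of 0.5 * ||x||^2 at x is x
  adagradStep(x, dfdx, 0.1, state)
end
print(x:norm())  -- should be much smaller than the initial norm
--]]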