-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlearn.m
More file actions
96 lines (83 loc) · 2.46 KB
/
learn.m
File metadata and controls
96 lines (83 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
function learn = learn(cond,us,values,rule,salient)
% LEARN  Update stimulus values after one trial using the selected rule.
%
%   newValues = LEARN(cond, us, values, rule, salient)
%
%   cond    - binary coding of which stimuli are present on the current
%             trial; must have the same number of elements as VALUES.
%   us      - scalar outcome of the trial. It is the learning target for
%             'basic', 'template' and 'rw'; 'template' and 'neuromod' only
%             change values when it is nonzero.
%   values  - current value of each stimulus (row or column vector).
%   rule    - 'basic'    : each present stimulus moves towards US by its own
%                          error (us - values(i)).
%             'template' : on nonzero US, present stimuli move towards US and
%                          absent stimuli are reduced by a fixed amount;
%                          negative values are then clipped to 0.
%             'rw'       : present stimuli share the summed prediction error
%                          us - sum(values of present stimuli), as in the
%                          Rescorla-Wagner model.
%             'neuromod' : on nonzero US, each present stimulus gains a
%                          fixed increment scaled by its salience.
%   salient - (optional, default false) if true, a total salience of 1 is
%             divided equally among the present stimuli; otherwise every
%             stimulus has salience 1.
%
%   Returns the updated values with the same size as VALUES.
%   Errors with id 'learn:unknownRule' for an unrecognised RULE and
%   'learn:sizeMismatch' when COND and VALUES differ in length.

if nargin < 5
    salient = false;
end

rate = 0.5;       % learning rate for 'basic', 'template' and 'rw'
decay = 0.2;      % 'template': subtracted from absent stimuli when reinforced
increase = 0.2;   % 'neuromod': added to present stimuli when reinforced

% Work on column vectors so row and column inputs behave identically; the
% result is reshaped back to the size of VALUES at the end. (The previous
% code always built a row increment, so a column VALUES expanded to a
% matrix.)
sz = size(values);
v = values(:);
c = double(cond(:));
if numel(c) ~= numel(v)
    error('learn:sizeMismatch', ...
        'cond (%d elements) and values (%d elements) must be the same length.', ...
        numel(c), numel(v));
end

% Assume total salience is divided equally among the stimuli present on
% this trial (only when SALIENT is set and at least one is present).
salience = ones(size(c));
if salient
    nPresent = sum(c);
    if nPresent
        salience = salience ./ nPresent;
    end
end
% Alternative (disabled) salience calculation giving higher-value stimuli
% higher salience: salience = (1 + values) ./ sum(cond).

switch rule
    case 'basic'
        % Each present stimulus is corrected by its own prediction error.
        v = v + rate .* salience .* c .* (us - v);

    case 'template'
        if us   % values only change on reinforced trials
            increment = rate .* salience .* c .* (us - v);
            increment(~c) = -decay;   % absent stimuli decay on reinforcement
            v = v + increment;
        end
        % Weights may not go negative (NaNs are left untouched).
        v(v < 0) = 0;
        % An earlier, disabled variant of this rule used the summed
        % prediction sum(values.*cond) as the error term and normalised the
        % weights to sum to 1 after clipping.

    case 'rw'
        % All present stimuli share the error against the summed prediction;
        % the prediction does not depend on the element, so compute it once.
        prediction = sum(v .* c);
        v = v + rate .* salience .* c .* (us - prediction);

    case 'neuromod'
        if us   % values only change on reinforced trials
            present = logical(c);
            v(present) = v(present) + salience(present) .* increase;
        end

    otherwise
        error('learn:unknownRule', ...
            'Unknown learning rule ''%s''; expected basic, template, rw or neuromod.', ...
            rule);
end

learn = reshape(v, sz);