-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAgent.java
More file actions
95 lines (93 loc) · 2.56 KB
/
Agent.java
File metadata and controls
95 lines (93 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/**
 * Tabular Q-learning agent.
 *
 * <p>The agent keeps a table {@code Q[state][action]} of action values and, while acting in an
 * {@link Environment}, updates it with the one-step rule
 * {@code Q(s,a) <- (1 - ALPHA) * Q(s,a) + ALPHA * (r + GAMMA * max_a' Q(s',a'))}.
 * Both learning and testing start in state 0 and run for {@code TMAX} steps.
 */
public class Agent {
    /* environment the agent interacts with */
    private final Environment env;
    /* discount rate */
    private static final double GAMMA = 0.9;
    /* learning rate */
    private static final double ALPHA = 0.1;
    /* number of time steps run by both learn() and test() */
    private static final int TMAX = 1000000;
    /* table size used by the single-argument constructor */
    private static final int DEFAULT_NUM_STATES = 4;
    private static final int DEFAULT_NUM_ACTIONS = 2;
    /* one random source for the whole run (previously a new Random was created on every step) */
    private final java.util.Random random;
    /* Q value table, indexed as Q[state][action]; every entry starts at 0.0 */
    private final double[][] Q;

    /**
     * Creates an agent with a 4-state x 2-action Q table initialized to zero.
     *
     * @param env environment used to observe next states and rewards
     */
    public Agent(Environment env) {
        this(env, DEFAULT_NUM_STATES, DEFAULT_NUM_ACTIONS, new java.util.Random());
    }

    /**
     * Creates an agent with a zero-initialized Q table of the given size.
     *
     * @param env environment used to observe next states and rewards; the states it returns
     *     are used as row indices, so they must lie in {@code [0, numStates)}
     * @param numStates number of rows in the Q table
     * @param numActions number of actions; actions are drawn from {@code [0, numActions)}
     * @param random random source used for exploration (pass a seeded instance for reproducible runs)
     * @throws IllegalArgumentException if {@code numStates} or {@code numActions} is not positive
     * @throws NullPointerException if {@code random} is null
     */
    public Agent(Environment env, int numStates, int numActions, java.util.Random random) {
        if (numStates <= 0 || numActions <= 0) {
            throw new IllegalArgumentException(
                    "numStates and numActions must be positive: " + numStates + ", " + numActions);
        }
        this.env = env;
        this.random = java.util.Objects.requireNonNull(random, "random");
        this.Q = new double[numStates][numActions];
    }

    /**
     * Runs {@code TMAX} steps of Q-learning starting from state 0, choosing each action with
     * {@link #select_action()} and updating {@code Q[state][action]} after every step.
     */
    public void learn() {
        int state = 0;
        for (int t = 0; t < TMAX; t++) {
            int action = select_action();
            // next state and reward are both observed for the (state, action) pair, in this order
            int next_state = env.observe_state(state, action);
            int reward = env.observe_reward(state, action);
            // bootstrap target uses the best action value at the next state, not the action taken
            double next_Q_max = max(Q[next_state]);
            Q[state][action] =
                    (1 - ALPHA) * Q[state][action] + ALPHA * (reward + GAMMA * next_Q_max);
            state = next_state;
        }
    }

    /*
     * Choose an action uniformly at random from [0, number of actions).
     * NOTE: this is pure random exploration, not epsilon-greedy -- Q is never consulted here.
     * The update in learn() always bootstraps from max(Q[next_state]), so it does not depend
     * on the behavior policy being greedy.
     */
    private int select_action() {
        return random.nextInt(Q[0].length);
    }

    /**
     * Follows the greedy policy {@code argmax_a Q[state][a]} for {@code TMAX} steps from state 0,
     * without updating Q, and prints the accumulated reward.
     */
    public void test() {
        // long: summing TMAX int rewards can overflow an int
        long total_reward = 0;
        int state = 0;
        for (int t = 0; t < TMAX; t++) {
            int action = argmax(Q[state]);
            int next_state = env.observe_state(state, action);
            int reward = env.observe_reward(state, action);
            state = next_state;
            total_reward += reward;
        }
        System.out.println("total reward");
        System.out.println(total_reward);
    }

    /**
     * Prints the Q table to standard output, one state per line, each value followed by a space.
     */
    public void print_Q() {
        System.out.println("Q Value");
        for (int i = 0; i < Q.length; i++) {
            for (int j = 0; j < Q[i].length; j++) {
                System.out.print(Q[i][j] + " ");
            }
            System.out.println();
        }
    }

    /**
     * Returns the largest element of {@code array}.
     *
     * @param array non-empty array of values
     * @return the maximum value
     * @throws IllegalArgumentException if {@code array} is empty
     */
    public double max(double[] array) {
        return array[argmax(array)];
    }

    /**
     * Returns the index of the largest element of {@code array}; on ties the lowest index wins.
     *
     * @param array non-empty array of values
     * @return index of the maximum value
     * @throws IllegalArgumentException if {@code array} is empty
     */
    public int argmax(double[] array) {
        if (array.length == 0) {
            throw new IllegalArgumentException("array must not be empty");
        }
        int max_index = 0;
        // start at 1: index 0 is the initial candidate; strict '>' keeps the first index on ties
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[max_index]) {
                max_index = i;
            }
        }
        return max_index;
    }
}