-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsarsa_lambda_learning.py
More file actions
57 lines (47 loc) · 2.14 KB
/
sarsa_lambda_learning.py
File metadata and controls
57 lines (47 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
import pandas as pd
from copy import deepcopy
class SARSALambdaLearning():
    """Tabular SARSA(lambda) agent with epsilon-greedy action selection.

    Q-values and eligibility traces are stored in pandas DataFrames indexed
    by state (one column per action). States are added lazily as they are
    first encountered.
    """

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, eligibility_trace_decay=0.9, df=None):
        """Configure the learner.

        actions: list of action labels; used as the Q-table columns.
        learning_rate: step size (alpha) for the Q update.
        reward_decay: discount factor (gamma).
        e_greedy: probability of taking the greedy action (epsilon).
        eligibility_trace_decay: trace-decay parameter (lambda).
        df: optional pretrained Q table (DataFrame with matching columns).
        """
        self.actions = actions
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.eligibility_trace_decay = eligibility_trace_decay
        # df is a data frame of a pretrained Q table
        if df is None:
            self.Q = pd.DataFrame(columns=self.actions, dtype=np.float64)
        else:
            self.Q = df
        # Eligibility traces must start at ZERO. The original copied the Q
        # table (deepcopy), which seeded the traces with pretrained Q-values
        # whenever df was supplied, corrupting the first updates.
        self.et = pd.DataFrame(0.0, index=self.Q.index, columns=self.Q.columns)

    def reset_eligibility_trace(self):
        """Zero all traces; call at the start of each new episode."""
        self.et.loc[:, :] = 0.0

    def choose_action(self, state):
        """Return an action for `state` via epsilon-greedy selection."""
        self.check_state_exist(state)
        rand = np.random.uniform()
        if rand < self.epsilon:
            # Exploit: take the action with the maximal Q-value. Permuting
            # the index first breaks ties randomly instead of always
            # favouring the lowest-index action.
            state_action = self.Q.loc[state, :]
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # Explore: uniform random action.
            action = np.random.choice(self.actions)
        return action

    def sarsa_lambda(self, state, action, reward, next_state, next_action):
        """Apply one on-policy SARSA(lambda) update for the observed transition."""
        self.check_state_exist(next_state)
        # .loc replaces DataFrame.ix, which was removed from pandas.
        predict = self.Q.loc[state, action]
        if next_state != 'terminal':
            target = reward + self.gamma * self.Q.loc[next_state, next_action]
        else:
            target = reward
        # Accumulating trace: bump the visited state-action pair.
        self.et.loc[state, action] += 1
        # Broadcast the TD error across every state-action pair, weighted
        # by its eligibility trace.
        self.Q += self.lr * (target - predict) * self.et
        # Decay all traces by gamma * lambda after the update.
        self.et *= self.gamma * self.eligibility_trace_decay

    def check_state_exist(self, state):
        """Add a zero-initialized row for `state` to Q and et if missing."""
        if state not in self.Q.index.astype(str):
            new_row = pd.DataFrame([[0.0] * len(self.actions)], index=[state], columns=self.Q.columns)
            # DataFrame.append was removed in pandas 2.0; concat replaces it.
            self.Q = pd.concat([self.Q, new_row])
            self.et = pd.concat([self.et, new_row.copy()])