增强学习--Sarsa算法

Sarsa算法

实例代码

 1 import numpy as np
 2 import random
 3 from collections import defaultdict
 4 from environment import Env
 5 
 6 
 7 # SARSA agent learns every time step from the sample <s, a, r, s', a'>
 8 class SARSAgent:
 9     def __init__(self, actions):
10         self.actions = actions
11         self.learning_rate = 0.01
12         self.discount_factor = 0.9
13         self.epsilon = 0.1
14         self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])#动作值函数表,q表,要更新的表,不同于mc的更新v表
15 
16     # with sample <s, a, r, s', a'>, learns new q function
17     def learn(self, state, action, reward, next_state, next_action):
18         current_q = self.q_table[state][action]
19         next_state_q = self.q_table[next_state][next_action]
20         new_q = (current_q + self.learning_rate *
21                 (reward + self.discount_factor * next_state_q - current_q))#q表更新公式
22         self.q_table[state][action] = new_q
23 
24     # get action for the state according to the q function table
25     # agent pick action of epsilon-greedy policy
26     def get_action(self, state):#获取下一步动作
27         #epsilon-greedy policy,exploration
28         if np.random.rand() < self.epsilon:
29             # take random action
30             action = np.random.choice(self.actions)
31         else:
32             # take action according to the q function table
33             state_action = self.q_table[state]
34             action = self.arg_max(state_action)
35         return action
36 
37     @staticmethod
38     def arg_max(state_action):
39         max_index_list = []
40         max_value = state_action[0]
41         for index, value in enumerate(state_action):
42             if value > max_value:
43                 max_index_list.clear()
44                 max_value = value
45                 max_index_list.append(index)
46             elif value == max_value:
47                 max_index_list.append(index)
48         return random.choice(max_index_list)
49 
if __name__ == "__main__":
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        # Reset the environment and choose the episode's first action
        # on-policy (SARSA evaluates the policy it is following).
        state = env.reset()
        action = agent.get_action(str(state))

        done = False
        while not done:
            env.render()

            # Advance one step and pick the successor action on-policy.
            next_state, reward, done = env.step(action)
            next_action = agent.get_action(str(next_state))

            # Learn from the complete sample <s, a, r, s', a'>.
            agent.learn(str(state), action, reward, str(next_state), next_action)

            state, action = next_state, next_action

            # Show the current Q-values of every state on screen.
            env.print_value_all(agent.q_table)
原文地址:https://www.cnblogs.com/buyizhiyou/p/10250114.html