增强学习--Q-learning

Q-learning

实例代码

 1 import numpy as np
 2 import random
 3 from environment import Env
 4 from collections import defaultdict
 5 
 6 class QLearningAgent:
 7     def __init__(self, actions):
 8         # actions = [0, 1, 2, 3]
 9         self.actions = actions
10         self.learning_rate = 0.01
11         self.discount_factor = 0.9
12         self.epsilon = 0.1
13         self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])#待更新q表
14 
15     # update q function with sample <s, a, r, s'>
16     def learn(self, state, action, reward, next_state):
17         current_q = self.q_table[state][action]
18         # using Bellman Optimality Equation to update q function
19         new_q = reward + self.discount_factor * max(self.q_table[next_state])
20         self.q_table[state][action] += self.learning_rate * (new_q - current_q)#更新公式,off-policy
21 
22     # get action for the state according to the q function table
23     # agent pick action of epsilon-greedy policy
24     def get_action(self, state):
25         #epsilon-greedy policy
26         if np.random.rand() < self.epsilon:
27             # take random action
28             action = np.random.choice(self.actions)
29         else:
30             # take action according to the q function table
31             state_action = self.q_table[state]
32             action = self.arg_max(state_action)
33         return action
34 
35     @staticmethod
36     def arg_max(state_action):
37         max_index_list = []
38         max_value = state_action[0]
39         for index, value in enumerate(state_action):
40             if value > max_value:
41                 max_index_list.clear()
42                 max_value = value
43                 max_index_list.append(index)
44             elif value == max_value:
45                 max_index_list.append(index)
46         return random.choice(max_index_list)
47 
if __name__ == "__main__":
    # Train a tabular Q-learning agent on the environment for a fixed number
    # of episodes, rendering the environment and the learned values at
    # every step.
    env = Env()
    agent = QLearningAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        state = env.reset()
        done = False

        # Run until the environment reports that the episode has ended.
        while not done:
            env.render()

            # Choose an action for the current state and advance one step.
            action = agent.get_action(str(state))
            next_state, reward, done = env.step(action)

            # Learn from the sample <s, a, r, s'>; states are keyed by their
            # string representation in the Q-table.
            agent.learn(str(state), action, reward, str(next_state))

            state = next_state
            env.print_value_all(agent.q_table)
原文地址:https://www.cnblogs.com/buyizhiyou/p/10250121.html