增强学习--值迭代

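值迭代

值迭代（value iteration）对每个非终止状态反复应用贝尔曼最优方程来更新值函数：

$$V_{k+1}(s) = \max_{a} \left[ R(s, a) + \gamma \, V_k(s') \right]$$

其中 $s'$ 是在状态 $s$ 执行动作 $a$ 后到达的下一个状态，$\gamma$ 是折扣因子（下面代码中取 0.9）。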

示例代码

 1 class ValueIteration:
 2     def __init__(self, env):
 3         self.env = env
 4         # 2-d list for the value function
 5         self.value_table = [[0.0] * env.width for _ in range(env.height)]
 6         self.discount_factor = 0.9
 7 
 8     # get next value function table from the current value function table
 9     def value_iteration(self):
10         next_value_table = [[0.0] * self.env.width
11                                     for _ in range(self.env.height)]
12         for state in self.env.get_all_states():
13             if state == [2, 2]:
14                 next_value_table[state[0]][state[1]] = 0.0
15                 continue
16             value_list = []
17 
18             for action in self.env.possible_actions:
19                 next_state = self.env.state_after_action(state, action)
20                 reward = self.env.get_reward(state, action)
21                 next_value = self.get_value(next_state)
22                 value_list.append((reward + self.discount_factor * next_value))
23             # return the maximum value(it is the optimality equation!!)
24             next_value_table[state[0]][state[1]] = round(max(value_list), 2)#每一次更新值函数表时取最大回报的动作更新
25         self.value_table = next_value_table
26 
27     # get action according to the current value function table
28     def get_action(self, state):
29         import pdb; pdb.set_trace()
30         action_list = []
31         max_value = -99999
32 
33         if state == [2, 2]:
34             return []
35 
36         # calculating q values for the all actions and
37         # append the action to action list which has maximum q value
38         for action in self.env.possible_actions:
39 
40             next_state = self.env.state_after_action(state, action)
41             reward = self.env.get_reward(state, action)
42             next_value = self.get_value(next_state)
43             value = (reward + self.discount_factor * next_value)
44 
45             if value > max_value:
46                 action_list.clear()
47                 action_list.append(action)
48                 max_value = value
49             elif value == max_value:
50                 action_list.append(action)
51 
52         return action_list
53 
54     def get_value(self, state):
55         return round(self.value_table[state[0]][state[1]], 2)
原文地址:https://www.cnblogs.com/buyizhiyou/p/10250089.html