增强学习--策略迭代

策略迭代

策略迭代交替执行两步：策略评估（用贝尔曼期望方程计算当前策略下各状态的值函数）和策略改进（根据值函数贪心地选择动作来更新策略），直至策略不再变化。

示例代码

 1 class PolicyIteration:
 2     def __init__(self, env):
 3         self.env = env
 4         # 2-d list for the value function
 5         self.value_table = [[0.0] * env.width for _ in range(env.height)]#值函数表
 6         # list of random policy (same probability of up, down, left, right)
 7         self.policy_table = [[[0.25, 0.25, 0.25, 0.25]] * env.width
 8                                     for _ in range(env.height)]#每一状态的动作策略表,一开始向四方运动是相同概率的
 9         # setting terminal state
10         self.policy_table[2][2] = []#吸收态,终止
11         self.discount_factor = 0.9
12 
13     def policy_evaluation(self):#策略估计
14         next_value_table = [[0.00] * self.env.width
15                                     for _ in range(self.env.height)]
16 
17         # Bellman Expectation Equation for the every states
18         for state in self.env.get_all_states():
19             value = 0.0
20             # keep the value function of terminal states as 0(吸收态赋0)
21             if state == [2, 2]:
22                 next_value_table[state[0]][state[1]] = value
23                 continue
24 
25             for action in self.env.possible_actions:#计算所有可能动作
26                 next_state = self.env.state_after_action(state, action)
27                 reward = self.env.get_reward(state, action)
28                 next_value = self.get_value(next_state)
29                 value += (self.get_policy(state)[action] *
30                           (reward + self.discount_factor * next_value))
31 
32             next_value_table[state[0]][state[1]] = round(value, 2)
33 
34         self.value_table = next_value_table
35 
36     def policy_improvement(self):#策略改进
37         next_policy = self.policy_table
38         for state in self.env.get_all_states():
39             if state == [2, 2]:
40                 continue
41             value = -99999
42             max_index = []
43             result = [0.0, 0.0, 0.0, 0.0]  # initialize the policy
44 
45             # for every actions, calculate 计算所有可能动作,保留取得最大值函数的动作
46             # [reward + (discount factor) * (next state value function)]
47             for index, action in enumerate(self.env.possible_actions):
48                 next_state = self.env.state_after_action(state, action)
49                 reward = self.env.get_reward(state, action)
50                 next_value = self.get_value(next_state)
51                 temp = reward + self.discount_factor * next_value
52 
53                 # We normally can't pick multiple actions in greedy policy.
54                 # but here we allow multiple actions with same max values 允许多个取最大值函数的动作存在
55                 if temp == value:
56                     max_index.append(index)
57                 elif temp > value:
58                     value = temp
59                     max_index.clear()
60                     max_index.append(index)
61 
62             # probability of action
63             prob = 1 / len(max_index)
64 
65             for index in max_index:
66                 result[index] = prob
67 
68             next_policy[state[0]][state[1]] = result#更新策略表
69 
70         self.policy_table = next_policy
71 
72     # get action according to the current policy
73     def get_action(self, state):
74         random_pick = random.randrange(100) / 100
75 
76         policy = self.get_policy(state)
77         policy_sum = 0.0
78         # return the action in the index
79         for index, value in enumerate(policy):
80             policy_sum += value
81             if random_pick < policy_sum:
82                 return index
83 
84     # get policy of specific state
85     def get_policy(self, state):
86         if state == [2, 2]:
87             return 0.0
88         return self.policy_table[state[0]][state[1]]
89 
90     def get_value(self, state):
91         return round(self.value_table[state[0]][state[1]], 2)
原文地址:https://www.cnblogs.com/buyizhiyou/p/10250082.html