# 策略迭代 (Policy Iteration)
# 实例代码 (Example code)
class PolicyIteration:
    """Tabular policy-iteration agent for a grid-world environment.

    The agent keeps a state-value table and a stochastic policy table, both
    indexed as ``table[state[0]][state[1]]`` with ``env.height`` rows and
    ``env.width`` columns.

    The environment object must provide ``width``, ``height``,
    ``possible_actions``, ``get_all_states()``,
    ``state_after_action(state, action)`` and ``get_reward(state, action)``.
    Each action is used directly as an index into a state's probability list,
    and every probability list has four entries, so the class assumes four
    actions numbered 0-3.
    """

    # Absorbing (terminal) state. It is compared with ``==`` against the states
    # yielded by the environment, so it is kept as a list like the original
    # literal. Treat it as read-only.
    _TERMINAL_STATE = [2, 2]

    def __init__(self, env):
        """Create an agent with a zero value function and a uniform policy.

        Args:
            env: Grid-world environment (see the class docstring for the
                attributes and methods that are used).
        """
        self.env = env
        # 2-D list holding the value function, initialised to 0.0 everywhere.
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        # Uniformly random initial policy: each of the four actions has the
        # same probability. Every cell gets its OWN list; multiplying a row
        # of lists would make all cells of a row share one list object.
        self.policy_table = [
            [[0.25, 0.25, 0.25, 0.25] for _ in range(env.width)]
            for _ in range(env.height)
        ]
        # The terminal state has no actions to choose from.
        self.policy_table[self._TERMINAL_STATE[0]][self._TERMINAL_STATE[1]] = []
        self.discount_factor = 0.9

    def policy_evaluation(self):
        """Run one sweep of the Bellman expectation backup over all states.

        A fresh table is computed from the current ``value_table`` and then
        replaces it, so every state is updated from the values of the previous
        sweep. Stored values are rounded to two decimals.
        """
        next_value_table = [
            [0.0] * self.env.width for _ in range(self.env.height)
        ]

        for state in self.env.get_all_states():
            # The value of the terminal state stays at its initial 0.0.
            if state == self._TERMINAL_STATE:
                continue

            action_probs = self.get_policy(state)
            value = 0.0
            # Expectation over all actions under the current policy.
            for action in self.env.possible_actions:
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                value += action_probs[action] * (
                    reward + self.discount_factor * next_value
                )

            next_value_table[state[0]][state[1]] = round(value, 2)

        self.value_table = next_value_table

    def policy_improvement(self):
        """Make the policy greedy with respect to the current value function.

        For each non-terminal state the action value
        ``reward + discount_factor * value(next_state)`` is computed for every
        action. All actions that tie for the maximum share the probability
        mass equally; every other action gets probability 0.

        ``policy_table`` is updated in place. This is safe because the method
        only reads the value table, never the policy it is overwriting.
        """
        for state in self.env.get_all_states():
            if state == self._TERMINAL_STATE:
                continue

            # -inf (rather than a large negative number) guarantees that the
            # first action value is always accepted, however negative it is.
            best_value = float("-inf")
            best_indices = []

            for index, action in enumerate(self.env.possible_actions):
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                action_value = reward + self.discount_factor * next_value

                # A greedy policy normally picks a single action, but here
                # several actions with the same maximal value are all kept.
                if action_value == best_value:
                    best_indices.append(index)
                elif action_value > best_value:
                    best_value = action_value
                    best_indices = [index]

            # Spread the probability evenly over the best actions.
            result = [0.0, 0.0, 0.0, 0.0]
            prob = 1 / len(best_indices)
            for index in best_indices:
                result[index] = prob

            self.policy_table[state[0]][state[1]] = result

    def get_action(self, state):
        """Sample an action index for ``state`` from the current policy.

        Args:
            state: Two-element state, ``[row_index, column_index]`` into the
                tables.

        Returns:
            The index of the sampled action, or ``None`` when the state has no
            action probabilities (the terminal state).
        """
        # Imported locally because no module-level import of ``random`` is
        # visible in this file.
        import random

        policy = self.get_policy(state)
        if not policy:
            # Terminal state: ``get_policy`` yields 0.0, nothing to sample.
            return None

        threshold = random.random()
        cumulative = 0.0
        fallback = None
        # Inverse-CDF sampling over the action probabilities.
        for index, probability in enumerate(policy):
            if probability <= 0:
                continue
            cumulative += probability
            if threshold < cumulative:
                return index
            fallback = index
        # Only reached if rounding leaves the cumulative sum below the drawn
        # threshold; fall back to the last action with positive probability.
        return fallback

    def get_policy(self, state):
        """Return the action-probability list stored for ``state``.

        For the terminal state the float ``0.0`` is returned instead of a
        list; this return value is kept unchanged for existing callers.
        """
        if state == self._TERMINAL_STATE:
            return 0.0
        return self.policy_table[state[0]][state[1]]

    def get_value(self, state):
        """Return the stored value of ``state``, rounded to two decimals."""
        return round(self.value_table[state[0]][state[1]], 2)