1 import sys
2 import gym
3 import pylab
4 import numpy as np
5 from keras.layers import Dense
6 from keras.models import Sequential
7 from keras.optimizers import Adam
8
# Upper bound on the number of training episodes run by the main loop
# (the loop may stop earlier once the recent average score is high enough).
EPISODES = 1000
10
11
# A2C (Advantage Actor-Critic) agent for CartPole.
# Actor-critic combines the value-based and the policy-based approaches:
# the critic learns a state value, the actor learns the policy.
class A2CAgent:
    """Advantage Actor-Critic agent built from two small Keras networks.

    The actor maps a state to a probability for each action; the critic
    maps a state to a single scalar value estimate.  Both are updated
    after every environment step by :meth:`train_model`.
    """

    def __init__(self, state_size, action_size):
        """Create the actor and critic networks.

        Args:
            state_size: Number of features in one observation.
            action_size: Number of discrete actions.
        """
        # Rendering is enabled; set to False to train without drawing
        # the environment on every step.
        self.render = True
        # Set to True to start from previously saved weights.
        self.load_model = False

        # Sizes of the state vector, the action set and the critic output.
        self.state_size = state_size
        self.action_size = action_size
        self.value_size = 1

        # Hyper-parameters of the policy-gradient update.
        self.discount_factor = 0.99
        self.actor_lr = 0.001
        self.critic_lr = 0.005

        # Policy (actor) and value (critic) networks.
        self.actor = self.build_actor()
        self.critic = self.build_critic()

        if self.load_model:
            self.actor.load_weights("./save_model/cartpole_actor.h5")
            self.critic.load_weights("./save_model/cartpole_critic.h5")

    # Policy and value are both approximated with neural networks.
    def build_actor(self):
        """Build and compile the actor network (state -> action probabilities).

        Returns:
            A compiled Keras ``Sequential`` model with a softmax output of
            width ``self.action_size``.
        """
        actor = Sequential()
        actor.add(Dense(24, input_dim=self.state_size, activation='relu',
                        kernel_initializer='he_uniform'))
        actor.add(Dense(self.action_size, activation='softmax',
                        kernel_initializer='he_uniform'))
        actor.summary()
        # See note regarding crossentropy in cartpole_reinforce.py.
        # train_model() fits this network against a vector that is zero
        # except for the advantage at the taken action, so the
        # cross-entropy term only involves that action's probability.
        # NOTE(review): newer Keras releases name this optimizer argument
        # `learning_rate`; `lr` is kept to match the Keras version this
        # file targets -- confirm before upgrading Keras.
        actor.compile(loss='categorical_crossentropy',
                      optimizer=Adam(lr=self.actor_lr))
        return actor

    def build_critic(self):
        """Build and compile the critic network (state -> scalar value).

        Returns:
            A compiled Keras ``Sequential`` model with a linear output of
            width ``self.value_size``, trained with mean squared error.
        """
        critic = Sequential()
        critic.add(Dense(24, input_dim=self.state_size, activation='relu',
                         kernel_initializer='he_uniform'))
        critic.add(Dense(self.value_size, activation='linear',
                         kernel_initializer='he_uniform'))
        critic.summary()
        critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
        return critic

    def get_action(self, state):
        """Sample an action from the actor's current policy.

        Args:
            state: Observation batch passed straight to ``actor.predict``;
                the caller in this file reshapes it to ``(1, state_size)``.

        Returns:
            The index of the sampled action.
        """
        # Action probabilities predicted by the actor for this state.
        policy = self.actor.predict(state, batch_size=1).flatten()
        return np.random.choice(self.action_size, 1, p=policy)[0]

    def train_model(self, state, action, reward, next_state, done):
        """Run one actor update and one critic update for a single transition.

        Args:
            state: State the action was taken in (fed to both networks).
            action: Index of the action that was taken.
            reward: Reward received for the transition.
            next_state: State reached after the action.
            done: True when the transition ended the episode; the target
                is then the reward alone, with no bootstrapped value.
        """
        target = np.zeros((1, self.value_size))        # shape (1, 1)
        advantages = np.zeros((1, self.action_size))   # shape (1, action_size)

        # Pull plain Python scalars out of the (1, value_size) predictions.
        # Assigning a shape-(1,) array into a single ndarray element relies
        # on implicit array-to-scalar conversion, which NumPy has deprecated.
        value = float(self.critic.predict(state)[0][0])

        if done:
            # Terminal transition: nothing to bootstrap from.
            td_target = reward
        else:
            next_value = float(self.critic.predict(next_state)[0][0])
            td_target = reward + self.discount_factor * next_value

        # Actor learns from the advantage (TD target minus the critic's
        # current estimate); critic regresses toward the TD target itself.
        advantages[0][action] = td_target - value
        target[0][0] = td_target

        self.actor.fit(state, advantages, epochs=1, verbose=0)
        self.critic.fit(state, target, epochs=1, verbose=0)
87
88
if __name__ == "__main__":
    import os

    # The plot and the weight files are written into these directories;
    # create them up front so the first save does not fail on a fresh
    # checkout.
    os.makedirs("./save_graph", exist_ok=True)
    os.makedirs("./save_model", exist_ok=True)

    # In case of CartPole-v1, maximum length of episode is 500
    env = gym.make('CartPole-v1')
    # get size of state and action from environment
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    # make A2C agent
    agent = A2CAgent(state_size, action_size)
    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0
        # This loop uses the classic gym API: reset() returns the
        # observation and step() returns a 4-tuple.
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()

            action = agent.get_action(state)
            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])
            # If an action ends the episode before the 500-step limit,
            # replace its reward with a penalty of -100.
            reward = reward if not done or score == 499 else -100

            # The agent is trained once per executed action.
            agent.train_model(state, action, reward, next_state, done)

            score += reward
            state = next_state

            if done:
                # Undo the -100 penalty so the logged score is the play
                # time, unless the episode ran the full 500 steps.
                score = score if score == 500.0 else score + 100
                scores.append(score)
                episodes.append(e)
                # Redraw from scratch: without clearing, every episode
                # would stack another full line onto the same figure.
                pylab.clf()
                pylab.plot(episodes, scores, 'b')
                pylab.savefig("./save_graph/cartpole_a2c.png")
                print("episode:", e, " score:", score)

                # If the mean score of the last 10 episodes (or of all
                # episodes so far, when fewer) exceeds 490, stop training.
                if np.mean(scores[-10:]) > 490:
                    # Persist the converged weights before exiting;
                    # otherwise only the last multiple-of-50 checkpoint
                    # would survive.
                    agent.actor.save_weights("./save_model/cartpole_actor.h5")
                    agent.critic.save_weights("./save_model/cartpole_critic.h5")
                    sys.exit()

        # Periodic checkpoint of both networks.
        if e % 50 == 0:
            agent.actor.save_weights("./save_model/cartpole_actor.h5")
            agent.critic.save_weights("./save_model/cartpole_critic.h5")