增强学习--蒙特卡洛方法

蒙特卡洛方法

示例代码

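下面的代码实现的是 constant-α 蒙特卡洛方法(constant-α Monte Carlo),原文在“这里”附有一段相关介绍的链接。其核心是在每个 episode 结束后,用固定步长(学习率)$\alpha$ 更新值函数:$V(S_t) \leftarrow V(S_t) + \alpha\,[G_t - V(S_t)]$,其中 $G_t$ 为从该状态开始的折扣累积回报。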

  1 import numpy as np
  2 import random
  3 from collections import defaultdict
  4 from environment import Env
  5 
  6 
  7 # Monte Carlo Agent which learns every episodes from the sample
  8 class MCAgent:
  9     def __init__(self, actions):
 10         self.width = 5
 11         self.height = 5
 12         self.actions = actions
 13         self.learning_rate = 0.01
 14         self.discount_factor = 0.9
 15         self.epsilon = 0.1
 16         self.samples = []
 17         self.value_table = defaultdict(float)#初始化值函数表,0
 18 
 19     # append sample to memory(state, reward, done)
 20     def save_sample(self, state, reward, done):
 21         self.samples.append([state, reward, done])
 22 
 23     # for every episode, agent updates q function of visited states
 24     def update(self):
 25         G_t = 0
 26         visit_state = []
 27         for reward in reversed(self.samples):#此处reverse,状态反转
 28             state = str(reward[0])
 29             if state not in visit_state:#first-visit MC methods
 30                 visit_state.append(state)
 31                 G_t = self.discount_factor * (reward[1] + G_t)#累积回报
 32                 value = self.value_table[state]
 33                 self.value_table[state] = (value +
 34                                            self.learning_rate * (G_t - value))
 35                 #constant-α monte carlo constant-α蒙特卡洛值函数更新
 36 
 37     # get action for the state according to the q function table
 38     # agent pick action of epsilon-greedy policy
 39     def get_action(self, state):
 40         if np.random.rand() < self.epsilon:#以epsilon概率随机选择,Exploration
 41             # take random action
 42             action = np.random.choice(self.actions)
 43         else:
 44             # take action according to the q function table
 45             next_state = self.possible_next_state(state)
 46             action = self.arg_max(next_state)
 47         return int(action)
 48 
 49     # compute arg_max if multiple candidates exit, pick one randomly
 50     @staticmethod
 51     def arg_max(next_state):
 52         max_index_list = []
 53         max_value = next_state[0]
 54         for index, value in enumerate(next_state):
 55             if value > max_value:
 56                 max_index_list.clear()
 57                 max_value = value
 58                 max_index_list.append(index)
 59             elif value == max_value:
 60                 max_index_list.append(index)
 61         return random.choice(max_index_list)
 62 
 63     # get the possible next states
 64     def possible_next_state(self, state):
 65         col, row = state
 66         next_state = [0.0] * 4 #四个方向,Q(s,a)
 67 
 68         if row != 0:
 69             next_state[0] = self.value_table[str([col, row - 1])]
 70         else:
 71             next_state[0] = self.value_table[str(state)]
 72 
 73         if row != self.height - 1:
 74             next_state[1] = self.value_table[str([col, row + 1])]
 75         else:
 76             next_state[1] = self.value_table[str(state)]
 77 
 78         if col != 0:
 79             next_state[2] = self.value_table[str([col - 1, row])]
 80         else:
 81             next_state[2] = self.value_table[str(state)]
 82 
 83         if col != self.width - 1:
 84             next_state[3] = self.value_table[str([col + 1, row])]
 85         else:
 86             next_state[3] = self.value_table[str(state)]
 87 
 88         return next_state
 89 
 90 
 91 # main loop
 92 if __name__ == "__main__":
 93     env = Env()
 94     agent = MCAgent(actions=list(range(env.n_actions)))
 95 
 96     for episode in range(1000):#episode task
 97         import pdb; pdb.set_trace()
 98         state = env.reset()
 99         action = agent.get_action(state)
100 
101         while True:
102             env.render()
103 
104             # forward to next state. reward is number and done is boolean
105             next_state, reward, done = env.step(action)
106             agent.save_sample(next_state, reward, done)
107 
108             # get next action
109             action = agent.get_action(next_state)
110 
111             # at the end of each episode, update the q function table
112             if done:
113                 print("episode : ", episode)
114                 agent.update()
115                 agent.samples.clear()
116                 break
原文地址:https://www.cnblogs.com/buyizhiyou/p/10250103.html