Sarsa模型和Q_learning模型简记

1. Sarsa模型

1.1 Sarsa类代码:

class SarsaAgent(object):
    def __init__(self,state_n,action_n,learning_rate=0.01,gamma=0.9,e_greed=0.1):
        """
        :param state_n:状态总数
        :param action_n:动作总数
        :param learning_rate:学习速率
        :param gamma:奖励衰减率
        :param e_greed:随机选择动作的概率,智能体有0.1的概率,在当前状态下随机选择动作action
        """
        self.act_n=action_n
        self.lr=learning_rate
        self.gamme=gamma
        self.epsilon=e_greed
        # 建立Q表,一共有state_n行,acton_n列
        self.Q = np.zeros((state_n,action_n))

1.2 sample函数:

        # 根据观察值,输出动作值
    def sample(self,state):
        # 如果讲武德,按照常理,应该从Q表中根据当前state选择action值比较大的
        if np.random.uniform(0,1)<(1-self.epsilon):
            action = self.predict(state)
        # 但也有0.1的概率,随机从action表中选取一个
        else:
            action = np.random.choice(self.act_n)
        return action

    def predict(self,obs):
        Q_list = self.Q[obs,]
        maxQ=np.max(Q_list) # 取Q表中当前状态下的最大action值
        action_list=np.where(Q_list==maxQ)[0]   # 然后把所有与最大action值相同的action变量,都取出来
        action = np.random.choice(action_list)  # 随机从这些action变量里取一个
        return action

Q值更新的公式:

[egin{align} Target=G_t & =R_{t+1}+gamma R_{t+2}+gamma^2R_{t+3}+{ldots} =sum_{k=0}^ngamma^kR_{t+k+1} \ G_t & =R_{t+1}+gamma R_{t+2}+gamma^2R_{t+3}+{ldots} \ & = R_{t+1}+gamma (R_{t+2}+gamma R_{t+3}+{ldots}) \ & = R_{t+1}+gamma G_{t+1} end{align} ]

'Sarsa模型':用下一个状态的Q值,来更新当前状态的Q值,也就是用G(t+1)来更新G(t)。在状态St下,需要知道的有At(当前状态下选择的动作),Rt(当前状态下选择动作后的回报),S(t+1)(下一个状态),A(t+1)(下一个状态选择的动作),然后根据这个五元组(St,At,Rt,S(t+1),A(t+1))来更新当前状态下的Q值。并且到了S(t+1)后一定会执行A(t+1)

1.3 Q值更新函数

def sarsa_learn(self,state,action,reward,next_state,next_action,done):
    """
    :param self:
    :param state:当前状态
    :param action: 当前状态下选择的动作
    :param reward: 当前状态下选择动作的回报
    :param next_state: 选择动作后的下一个状态
    :param next_action: 下一个状态下选择的动作
    :param done: 是否到达目的地,到达目的地后就奖励reward
    :return:
    """

    #  predict_Q :预测值,也就是当前状态下,选择动作后的回报
    #  target_Q : 目标值,
    predict_Q = self.Q[state,action]
    if done:
        target_Q = reward
    else:
        target_Q = reward + self.gamma*self.Q[next_state,next_action]
    # 当前状态下Q值更新,时序差分,不懂时序差分的可以百度了解一下
    self.Q[state,action]+=self.lr*(target_Q-predict_Q)

2. Q_learning模型

Q值更新的公式:

[Q(S_t,A_t)=Q(S_t,A_t)+alpha[R_{t+1}+gamma{max_a}{Q}(S_{t+1},a)-Q(S_t,A_t)] ]

Q_learning模型:不需要知道下一个状态选择的是那个动作,根据下一个状态S(t+1),求得Q值最大的action,然后利用最大的action来更新当前状态St的Q值,也就是会默认用下一个状态的Q值最大的动作来更新当前状态Q值。但是到了状态S(t+1)后,不一定执行动作action,因为还会有一个随机的概率来随机选择动作

Q_learning 的代码除更新的公式那里不一样,其余基本都一样,Sarsa需要计算下一个状态下的action,Q_learning需要计算下一个状态下的最大的Q值(不管是那个动作)。

2.1 Q值更新函数如下:

def sarsa_learn(self,state,action,reward,next_state,next_action,done):
    """
    :param self:
    :param state:当前状态
    :param action: 当前状态下选择的动作
    :param reward: 当前状态下选择动作的回报
    :param next_state: 选择动作后的下一个状态
    :param next_action: 下一个状态下选择的动作
    :param done: 是否到达目的地,到达目的地后就奖励reward
    :return:
    """

    #  predict_Q :预测值,也就是当前状态下,选择动作后的回报
    #  target_Q : 目标值,
    predict_Q = self.Q[state,action]
    if done:
        target_Q = reward
    else:
        target_Q = reward + self.gamma*np.max(self.Q[next_state,:])
    # 当前状态下Q值更新,时序差分,不懂时序差分的可以百度了解一下
    self.Q[state,action]+=self.lr*(target_Q-predict_Q)
原文地址:https://www.cnblogs.com/52dxer/p/14006778.html