(六)Value Function Approximation-LSPI code (4)

本篇是solver.py

  1 # -*- coding: utf-8 -*-
  2 """Contains main LSPI method and various LSTDQ solvers."""
  3 
  4 import abc
  5 import logging
  6 
  7 import numpy as np
  8 
  9 import scipy.linalg
 10 
 11 
 12 class Solver(object):#这里也出现一个继承ABC类的类了
 13 
 14     r"""ABC for LSPI solvers.
 15 
 16     Implementations of this class will implement the various LSTDQ algorithms
 17     with various linear algebra solving techniques. This solver will be used
 18     by the lspi.learn method. The instance will be called iteratively until
 19     the convergence parameters are satisified.
 20 
 21     """
 22 
 23     __metaclass__ = abc.ABCMeta#继承
 24 
 25     @abc.abstractmethod#必须覆盖的函数
 26     def solve(self, data, policy):#求解函数
 27         r"""Return one-step update of the policy weights for the given data.
 28             #该函数对于给出的数据更新一步权重
 29         Parameters#输入参数
 30         ----------
 31         data:#数据
 32             #求解器需要的数据,通常是一个元素是采样的列表,当然也可以是各种求解器支持的方法
 33             This is the data used by the solver. In most cases this will be
 34             a list of samples. But it can be anything supported by the specific
 35             Solver implementation's solve method.
 36         policy: Policy#策略
 37             当前的策略,要对它进行提升
 38             The current policy to find an improvement to.
 39 
 40         Returns
 41         -------
 42         numpy.array#输出的权重
 43             Return the new weights as determined by this method.
 44 
 45         """
 46         pass  # pragma: no cover
 47 
 48 
 49 class LSTDQSolver(Solver):#最小二乘TDQ求解器
 50 
 51     """LSTDQ Implementation with standard matrix solvers.
 52     #用矩阵的形式实现
 53     #算法根据文献的第五张图,如果矩阵A是满秩的,那么就用scipy的库来计算
 54     #如果不满秩,就用最小二乘的方法
 55     Uses the algorithm from Figure 5 of the LSPI paper. If the A matrix
 56     turns out to be full rank then scipy's standard linalg solver is used. If
 57     the matrix turns out to be less than full rank then least squares method
 58     will be used.
 59     #通常矩阵A的对角线值是小的正数值,这用来保证即使是很少的采样,矩阵A也能满秩,如果
 60     #不想要这样的前提,可以让前提条件值为0
 61     By default the A matrix will have its diagonal preconditioned with a small
 62     positive value. This will help to ensure that even with few samples the
 63     A matrix will be full rank. If you do not want the A matrix to be
 64     preconditioned then you can set this value to 0.
 65 
 66     Parameters前提条件值
 67     ----------
 68     precondition_value: float
 69         Value to set A matrix diagonals to. Should be a small positive number.
 70         If you do not want preconditioning enabled then set it 0.
 71     """
 72 
 73     def __init__(self, precondition_value=.1):#初始化
 74         """Initialize LSTDQSolver."""
 75         self.precondition_value = precondition_value#对前提条件值赋值
 76 
 77     def solve(self, data, policy):#求解函数
 78         """Run LSTDQ iteration.
 79 
 80         See Figure 5 of the LSPI paper for more information.
 81         """
 82         k = policy.basis.size()#k是特征phi向量的长度
 83         a_mat = np.zeros((k, k))#建立A矩阵,k行k列
 84         np.fill_diagonal(a_mat, self.precondition_value)#向矩阵A中填充前提条件值
 85         #说明前提条件值是用来保证矩阵是满秩的
 86 
 87         b_vec = np.zeros((k, 1))#b向量
 88 
 89         for sample in data:#对于data中的每一个采样进行循环
 90             phi_sa = (policy.basis.evaluate(sample.state, sample.action)
 91                       .reshape((-1, 1)))#通过basisfunction求出phi值
 92 
 93             if not sample.absorb:
 94                 best_action = policy.best_action(sample.next_state)#计算下一个状态下的最佳动作
 95                 phi_sprime = (policy.basis
 96                               .evaluate(sample.next_state, best_action)
 97                               .reshape((-1, 1)))#计算一个新的phi
 98             else:
 99                 phi_sprime = np.zeros((k, 1))
100 
101             a_mat += phi_sa.dot((phi_sa - policy.discount*phi_sprime).T)#计算a矩阵
102             b_vec += phi_sa*sample.reward#计算b矩阵
103 
104         a_rank = np.linalg.matrix_rank(a_mat)
105         if a_rank == k:#如果满秩
106             w = scipy.linalg.solve(a_mat, b_vec)#求逆解出w值
107         else:
108             logging.warning('A matrix is not full rank. %d < %d', a_rank, k)
109             w = scipy.linalg.lstsq(a_mat, b_vec)[0]
110         return w.reshape((-1, ))#返回已经优化后的w值.
原文地址:https://www.cnblogs.com/lijiajun/p/5490041.html