# Simple gradient descent on a function of two variables (can be extended to more independent variables)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D

# Objective function: f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8
def targetFunc(x, y):
    """Evaluate f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8.

    Only arithmetic operators are applied to the arguments, so the function
    accepts plain numbers as well as the NumPy mesh-grid arrays passed in by
    the script section of this file.
    """
    square_terms = 2 * (x ** 2) + 6 * y ** 2
    cross_term = 6 * x * y
    quadratic_part = square_terms + cross_term
    return quadratic_part + x + 4 * y + 8

# Partial derivatives of f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8:
#   df/dx = 4x + 6y + 1
#   df/dy = 12y + 6x + 4
def derivativeFunc(x, y):
    """Return the gradient of the objective at (x, y) as the tuple (df/dx, df/dy)."""
    return 4 * x + 6 * y + 1, 12 * y + 6 * x + 4

# Trajectory recorded by linerFunc: one (x1, ..., xn, f(x)) tuple per point the
# descent stepped away from. It is module-level state and is never cleared
# here, so repeated calls keep appending to the same list.
pointList = []


def linerFunc(initPoint: tuple, targetFunc, derivativeFunc, step=0.01,
              limitValue=0.00000001, timeout=1000000, ax: "Axes3D" = None):
    """Minimise ``targetFunc`` with fixed-step gradient descent.

    Starting from ``initPoint`` the point is repeatedly moved by
    ``-step * gradient``. A coordinate whose move ``|gradient_i * step|`` is
    smaller than ``limitValue`` is left where it is for that iteration.

    The search stops when no gradient component changes by more than
    ``limitValue`` between two successive points, or when the step counter
    (which starts at 1) reaches ``timeout``.

    Parameters:
        initPoint: starting coordinates, one entry per independent variable.
        targetFunc: callable taking the coordinates as positional arguments
            and returning the function value.
        derivativeFunc: callable taking the same arguments and returning the
            partial derivatives, one per coordinate.
        step: learning rate multiplying the gradient.
        limitValue: threshold used both for the stopping test and for
            freezing individual coordinates.
        timeout: upper bound for the step counter.
        ax: accepted for interface compatibility; not used by this function.

    Returns:
        ``(value, point)``: the function value at the last computed point and
        that point as a NumPy array.

    Side effects:
        Appends ``(*point, value)`` to the module-level ``pointList`` for the
        start point and for every point the loop steps away from (the
        returned point itself is not appended), and prints the final value of
        the step counter.

    Both callables are assumed to be deterministic: the value and gradient
    computed for a candidate point are reused when that candidate becomes the
    current point instead of being evaluated a second time.
    """
    count = 1
    current = np.array(initPoint)
    value, grad = targetFunc(*current), np.array(derivativeFunc(*current))
    pointList.append((*current, value))

    newPoint = current - grad * step
    rn, dn = targetFunc(*newPoint), np.array(derivativeFunc(*newPoint))
    diff = np.abs(grad - dn)

    while (diff > limitValue).any() and count < timeout:
        # The candidate becomes the current point; its value and gradient are
        # already known from the previous evaluation, so reuse them.
        current, value, grad = newPoint, rn, dn

        # Only move the coordinates whose update is still at least limitValue.
        move = grad * step
        newPoint = np.where(np.abs(move) >= limitValue, current - move, current)
        rn, dn = targetFunc(*newPoint), np.array(derivativeFunc(*newPoint))
        diff = np.abs(grad - dn)

        pointList.append((*current, value))
        count += 1

    print("最终运算次数为 : {0}".format(count))
    return rn, newPoint


if __name__ == "__main__":
    # Sample the objective on a 100 x 100 grid over [-2, 23] x [-2, 23] to
    # draw it as a surface.
    x, y = np.linspace(-2, 23, 100), np.linspace(-2, 23, 100)
    x, y = np.meshgrid(x, y)
    fxy = targetFunc(x, y)

    fig = plt.figure()
    # Create the 3D axes through the figure. Constructing Axes3D(fig)
    # directly no longer attaches the axes to the figure automatically in
    # recent Matplotlib releases, which leaves the window empty.
    ax = fig.add_subplot(111, projection="3d")

    ax.plot_surface(x, y, fxy)
    limitValue, limitPoint = linerFunc((20, 20), targetFunc, derivativeFunc, ax=ax)
    # pointList holds (x, y, f) tuples; transposing gives the three coordinate
    # rows scatter expects. The label gives ax.legend() an entry to show.
    ax.scatter(*(np.array(pointList).T), c='r', s=20, label="descent path")
    print("该函数在({0},{1})处有驻点,值为{2}".format(limitPoint[0],limitPoint[1],limitValue))
    ax.legend()
    plt.show()

# Original article: https://www.cnblogs.com/dofstar/p/11462941.html