梯度下降算法

  之前在博客中写了三个语言版本的(matlab的有点小问题)。

  P.S.把代码中表达式的函数改为回调函数,这样可以自己定制方程,对于各种函数也就可以仅调用gradient_descent()方法来进行求解。

  下面是一个python的对函数 $ f(vec{x}) = 2 x_1^2 + x_2^2 $ 求梯度下降的实现:

square = lambda x: x ** 2
side = lambda a, b: (a ** 2 + b ** 2) ** (1 / 2)


def d(x_1, x_2):
    return -4 * x_1, -2 * x_2


def evaluate(x_1, x_2, a, b):
    return -(2 * a * x_1 + b * x_2) / (2 * square(a) + square(b))


def gradient_descent(x_1, x_2, epsilon, max_iters):
    iters = 0
    a, b = 1, 1
    while (side(a, b) > epsilon) & (iters < max_iters):
        print('%d (%.6f, %.6f)' % (iters + 1, x_1, x_2))
        a, b = d(x_1, x_2)
        gamma = evaluate(x_1, x_2, a, b)
        (x_1, x_2) = (x_1 + a * gamma, x_2 + b * gamma)
        iters += 1


if __name__ == '__main__':
    gradient_descent(1, 1, 0.001, 100)

  C/C++版( $ epsilon = 10^{-6} = 1e-6 $ ,lambda函数就是求参数λ,因为迭代公式:$ vec{x_{i+1}} = vec{x_i} + lambda_i cdot vec{d} $  ,这里$ vec d $为函数在 $ (x_{1_i}, x_{2_i}) $ 处的梯度值):

#include <iostream>
double square(double x)
{
	return x*x;
}
double part_a(double x)
{
	return 4 * (-x);
}
double part_b(double x)
{
	return 2 * (-x);
}
double hypotenuse(double a, double b)
{
	return sqrt(square(a) + square(b));
}
double lambda(double x_1, double x_2, double a, double b)
{
	return -(2 * a*x_1 + b*x_2) / (2 * square(a) + square(b));
}
void gradient_descent(
	double x_1,
	double x_2,
	double epsilon,
	unsigned max_iters)
{
	double a, b, λ;
	unsigned iters = 0;
	do
	{
		printf("(%f, %f)
", x_1, x_2);
		a = part_a(x_1), b = part_b(x_2);
		λ = lambda(x_1, x_2, a, b);
		x_1 += a*λ;
		x_2 += b*λ;
		iters++;
	} while (hypotenuse(a, b) > epsilon && iters < max_iters);
}
int main()
{
	double epsilon = 1e-6;
	double x_1 = 1., x_2 = 1.;
	gradient_descent(x_1, x_2, epsilon, 1000);
	getchar();
	return 0;
}

  output:

(1.000000, 1.000000)
(-0.111111, 0.444444)
(0.074074, 0.074074)
(-0.008230, 0.032922)
(0.005487, 0.005487)
(-0.000610, 0.002439)
(0.000406, 0.000406)
(-0.000045, 0.000181)
(0.000030, 0.000030)
(-0.000003, 0.000013)
(0.000002, 0.000002)
(-0.000000, 0.000001)
(0.000000, 0.000000)

  

原文地址:https://www.cnblogs.com/darkchii/p/9408542.html