三分法求凸函数的极值

作者:jostree 转载请注明出处 http://www.cnblogs.com/jostree/p/4397990.html

在机器学习中,求凸函数的极值是一个常见的问题,常见的方法如梯度下降法,牛顿法等,今天我们介绍一种三分法来求一个凸函数的极值问题。

对于如下图的一个凸函数$f(x),xin [left,right]$,其中lm和rm分别为区间[left,right]的三等分点,我们发现如果f(lm)<f(rm),那么函数值最小的点的横坐标x一定在[left,rm]之间。如果x在[rm,right]之间,就会出现在rm左右都有比他低的点,这显然是不可能的。 同理,当f(lm)>f(rm)时,最值的横坐标x一定在[lm,right]的区间内。

利用这个性质,我们就可以在缩小区间的同时向目标点逼近,从而得到极值。


举一个例子,题目源自http://hihocoder.com/contest/hiho40/problem/1,如下图在直角坐标系中有一条抛物线y=ax^2+bx+c和一个点P(x,y),求点P到抛物线的最短距离d,其中-200≤a,b,c,x,y≤200。我们另pivot代表抛物线的对称抽,可以发现当X>pivot,我们可以取left = pivot,right = inf, 反之left = -inf , right = pivot, 其距离恰好满足凸形函数。而我们要求的最短距离d,正好就是这个凸形函数的极值。

代码如下:

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <iostream>
#include <cmath>

using namespace std;
double a, b, c, x, y;
const double MAX = 100000;
double dis(double X)
{
    double Y = a*X*X+b*X+c;
    return sqrt((x-X)*(x-X)+(y-Y)*(y-Y));
}

double solve(double l, double r)
{
    double lm = l + (r-l)/3;
    double rm = r - (r-l)/3;
    double lmd = dis(lm);
    double rmd = dis(rm);
    if( fabs(lmd - rmd) < 0.0001 )
    {
        return lmd;
    }
    if( lmd > rmd )
    {
        return solve(lm, r);
    }
    else
    {
        return solve(l, rm);
    }
}

int main(int argc, char *argv[])
{
    while( cin>>a>>b>>c>>x>>y )
    {
        double pivot = -b/(2*a);
        double l = 0, r = 0;
        if( pivot < x )
        {
            l = pivot + 0.0001;
            r = MAX;
        }
        else
        {
            l = -MAX;
            r = pivot - 0.0001;
        }
        double res = solve(l, r);
        printf("%.3lf
", res);
    }
}
原文地址:https://www.cnblogs.com/jostree/p/4397990.html