阅读笔记——斜率优化

机器学习的内容过两天再写吧,昨晚心血来潮做了个算法题,写个题解凑篇博客。

题目来源:08年湖南NOI,在洛谷上可以找到这个题目,省选难度的题,可以说难度不小,链接:https://www.luogu.com.cn/problem/P3195

前两天刚写了有关增强学习的内容,提到了动态规划,我个人比较擅长做这类题。这道题显然也是动态规划解法,这个题简单的一个地方在于要求排列有序且不考虑容器个数,很显然一维dp就能到顶。状态dp[i]表示前i个玩具的总花费,其前一个状态dp[j]表示前j个玩具的总花费,剩下的第j+1到第i个玩具装到一个容器中,花费为(sum[i]+i-sum[j]-j-L-1)^2,sum[i]表示前i个玩具长度之和。很容易得到状态转移方程为dp[i] = min{dp[j]+ (sum[i]+i-sum[j]-j-L-1)^2},其中0<=j<i。编码的时候注意数据范围,我用的是Java,所以数据要用long来存储,直接使用这个状态转移方程,写两个for循环枚举j,i的所有情况,只通过了30%的数据,剩下的全部超时,作为一个省赛题目肯定不是让写两个for循环这么简单,一维dp却要用O(n^2)来解决,显然是有可以优化的空间的。

我们观察一下这个方程:dp[i] = min{dp[j]+ (sum[i]+i-sum[j]-j-L-1)^2},令C=L+1,方程简化为:dp[i] = min{dp[j]+ (sum[i]+i-sum[j]-j-C)^2},即存在一个j介于0到i-1之间,使dp[i] = dp[j]+ (sum[i]+i-sum[j]-j-C)^2,我们对该式进行展开变形,令s[t]=sum[t]+t,可以得到如下表达式,dp[i]+2*s[i]*(s[j]+C)=dp[j]+s[i]^2+(s[j]+C)^2,

令y= dp[j]+s[i]^2+(s[j]+C)^2,k=2*sum[i],x=s[j]+C,b=dp[i],这个方程就变为了一个直线系y=kx+b,在求解dp[i]的情况下,dp[0..i-1]都是已经求出来的,sum[0..i]都是已知量,在斜率k=2*sum[i]已经确定的情况下(这点要明确),直线通过点(sum[j]+C,dp[j]+sum[i]^2+(s[j]+C)^2),求最小截距dp[i]。此时使用单调队列进行维护,队头斜率小的出队,队尾斜率大的出队,并加入当前结点i,维持队列中头两个结点斜率最小的状态,O(n)内即可搞定。

参考代码:

import java.util.Scanner;

public class Main{
    public static long[]dp = new long[50001];
    public static long[]sum = new long[50001];
    public static long l;
    public static double a(int pos){
        return sum[pos];
    }
    public static double b(int pos){
        return dp[pos]+(sum[pos]+l)*(sum[pos]+l);
    }
    public static double ss(int pos1,int pos2){
        return (b(pos1)-b(pos2))/(a(pos1)-a(pos2));
    }

    public static void main(String[]args){


        int[]q = new int[50001];
        Scanner s = new Scanner(System.in);
        int n = s.nextInt(),head = 1,tail = 1;
        l = s.nextLong()+1;  //这里的L+1,后面直接-L即可
       
for(int i = 1;i<=n;++i){
            sum[i] = s.nextLong();
            sum[i] += sum[i-1];
        }
        for(int i = 1;i<=n;++i)sum[i]+=i;
        q[head] = 0;
        int j = 0;
        for(int i = 1;i<=n;++i){
            while(head<tail&&ss(q[head],q[head+1])<2*sum[i])head++;
            j = q[head];dp[i] = dp[j]+(sum[i]-sum[j]-l)*(sum[i]-sum[j]-l);
            while(head<tail&&ss(q[tail-1],q[tail])>ss(q[tail],i))tail--;
                q[++tail] = i;
        }
        System.out.println(dp[n]);
        return;
    }

}

使用如上代码,AC。

原文地址:https://www.cnblogs.com/messi2017/p/12271978.html