斜率优化DP

现在是晚上零点三十分,我来写这篇文章,总结一下我今天学到的简单的斜率优化。

什么是斜率优化,就是将递推式写为y=kx+b的形式。

假设原递推式长这样:f[i]=min{f[j]+C},其中C可能是一个关于i的函数,一个关于j的函数,一个关于i和j的函数。

前两种情况可以通过单调队列来解决,但是情况三中无法分离i和j,只能使用斜率优化。

y=kx+b,其中y=f(j),k=f(i),x=f(j),b=f(i)+const,其中f(x)代表与x有关的一个函数。

光说是空的,我们来具体分析。

例题一:HDU3507 Print Article

网址:http://acm.hdu.edu.cn/showproblem.php?pid=3507

不难得到递推式f[i]=min{f[j]+(sum[i]-sum[j])^2+m};

去掉min函数并展开:f[i]=f[j]+sum[i]^2-2*sum[i]*sum[j]+sum[j]^2+m

将含i的项与含j的项分离,并把单纯含j的项写在左边:f[j]+sum[j]^2=2*sum[i]*sum[j]+f[i]-sum[i]^2-m

现在把f[j]+sum[j]^2看做y,2*sum[i]看做k,sum[j]看做x,f[i]-sum[i]^2-m看做b。

对于每一个点(sum[j],f[j]+sum[j]^2)都是固定的,对于每一个i,斜率2*sum[i]也是固定的,而截距也就是b。

b越小,f[i]就越小。我们可以形象的理解为:一条斜率为2*sum[i]的线从下往上扫,扫到的第一个点就是答案。

我们构造一个队列,维护相邻两点间的斜率。如果能保证斜率单调递增,那么第一个满足斜率大于2*sum[i]的就是结果。

由于2*sum[i]也是单调递增的,所以不大于它的可以出队。时间复杂度N。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=500000+10;
int n,m,x,s[maxn],f[maxn],q[maxn];
int yval(int x,int y){
    return f[y]-f[x]+s[y]*s[y]-s[x]*s[x];
}//计算y坐标的差
int xval(int x,int y){
    return s[y]-s[x];
}//计算x坐标的差
signed main(){
    while(cin>>n>>m){
        for(int i=1;i<=n;i++){
            scanf("%lld",&x);
            s[i]=s[i-1]+x;
        }
        memset(f,0x3f,sizeof(f));
        memset(q,0,sizeof(q));
        int l=1,r=1;
        q[l]=0,f[0]=0;
        for(int i=1;i<=n;i++){
            while(l<r&&yval(q[l],q[l+1])<=xval(q[l],q[l+1])*2*s[i])l++;
       //当队首不满足斜率大于当前斜率,则出队 f[i]
=f[q[l]]+(s[i]-s[q[l]])*(s[i]-s[q[l]])+m;
       //计算f[i]
while(l<r&&yval(q[r-1],q[r])*xval(q[r],i)>=xval(q[r-1],q[r])*yval(q[r],i))r--;
       //若队尾不满足单调性,则出队 q[
++r]=i; } printf("%lld ",f[n]); } return 0; }

例题2:P3195 [HNOI2008]玩具装箱TOY /【模板】斜率优化

网址:https://www.luogu.com.cn/problem/P3195
 
不难推出递推方程:dp[i]=min{dp[j]+(sum[i]+isum[j]jL1)^2}
 
不妨设A=i+sum[i],B=j+sum[j]+1。则递推式可以转化为:dp[i]=min{dp[j]+(A-B+L)^2}
 
接下来的部分留给读者自行推倒(主要是我懒得写了QAQ)
 
看代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=100000;
int n,L,f[maxn],q[maxn];
int c[maxn],sum[maxn];
int a(int i){
    return i+sum[i];
}
int b(int i){
    return i+sum[i]+1;
}
int x(int i,int j){
    return b(j)-b(i);
}
int y(int i,int j){
    return f[j]+b(j)*b(j)-f[i]-b(i)*b(i);
}
signed main(){
    cin>>n>>L;
    for(int i=1;i<=n;i++){
        scanf("%lld",&c[i]);
        sum[i]=sum[i-1]+c[i];
    }
    memset(f,0x3f,sizeof(f));
    int l=1,r=1;
    q[l]=0;f[0]=0;
    for(int i=1;i<=n;i++){
        while(l<r&&y(q[l],q[l+1])<=x(q[l],q[l+1])*2*(a(i)-L))l++;
        f[i]=f[q[l]]+(a(i)-b(q[l])-L)*(a(i)-b(q[l])-L);
        while(l<r&&y(q[r-1],q[r])*x(q[r],i)>=x(q[r-1],q[r])*y(q[r],i))r--;
        q[++r]=i;
    }
    cout<<f[n]<<endl;
    return 0;
}

例题3:P2120 [ZJOI2007]仓库建设

网址:https://www.luogu.com.cn/problem/P2120

容易想到,我们必须在最低点也就是第n个点建设一个仓库。由此想到这样设状态:f[i]表示仅处理1~i的工厂的最小花费。

找到一个j<i,在j处建工厂,从j+1到i-1的都运到i处,那么推出状转方程:f[i]=min{f[j]+c[i]+p[k]*(x[i]-x[k])},k从i+1到j-1

预处理两个前缀和,设sp为p的前缀和,spx为p*x的前缀和。

可以推出简化版的N^2的状转方程:f[i]=min{f[j]+c[i]+x[i]*(sp[i-1]-sp[j])-spx[i-1]+spx[j]}

写成一次函数的形式:f[j]+spx[j]=x[i]*sp[j]+f[i]-c[i]-x[i]*sp[i-1]+spx[i-1],其中y=f[j]+spx[j],k=x[i],x=sp[j],b=f[i]-c[i]-x[i]*sp[i-1]+spx[i-1]

数据保证了x[i]单调递增,非常好!!!

看代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=1000000+10;
int c[maxn],x[maxn],p[maxn],q[maxn];
int sp[maxn],spx[maxn],f[maxn],n;
inline int yval(int a,int b){
    return f[b]-f[a]+spx[b]-spx[a];
}
inline int xval(int a,int b){
    return sp[b]-sp[a];
}
signed main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        scanf("%lld%lld%lld",&x[i],&p[i],&c[i]);
        sp[i]=sp[i-1]+p[i];
        spx[i]=spx[i-1]+p[i]*x[i];
    }
    memset(f,0x3f,sizeof(f));
    f[0]=0;
    int l=1,r=1;
    q[l]=0;
    for(int i=1;i<=n;i++){
        while(l<r&&yval(q[l],q[l+1])<=xval(q[l],q[l+1])*x[i])l++;
        f[i]=f[q[l]]+c[i]+x[i]*(sp[i-1]-sp[q[l]])-spx[i-1]+spx[q[l]];
        while(l<r&&yval(q[r-1],q[r])*xval(q[r],i)>=xval(q[r-1],q[r])*yval(q[r],i))r--;
        q[++r]=i;
    }
    printf("%lld
",f[n]);
    return 0;
}

例题4:P3628 [APIO2010]特别行动队

网址:https://www.luogu.com.cn/problem/P3628

这题不像前面那样板子了,至少我不认为它是个板子。

推出的方程长这样子:f[j]+a*sum[j]^2-b*sum[j]=(2*a*sum[i])*sum[j]+f[i]-a*sum[i]^2-b*sum[i]-c

看一下数据范围,a恒为负,这时斜率2*a*sum[i]单调递减,同时我们要求截距的最大值。

此时维护上凸壳(斜率单调递减)

看代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=1000000+10;
int n,a,b,c,sum[maxn],f[maxn],q[maxn];
int yval(int x,int y){
    return f[y]+a*sum[y]*sum[y]-b*sum[y]-f[x]-a*sum[x]*sum[x]+b*sum[x];
}
int xval(int x,int y){
    return sum[y]-sum[x];
}
signed main(){
    cin>>n>>a>>b>>c;
    int tmp;
    for(int i=1;i<=n;i++){
        scanf("%lld",&tmp);
        sum[i]=sum[i-1]+tmp;
        f[i]=-1e12;
    }
    int l=1,r=1;
    q[l]=0;
    for(int i=1;i<=n;i++){
        while(l<r&&yval(q[l],q[l+1])>=xval(q[l],q[l+1])*2*a*sum[i])l++;
        f[i]=f[q[l]]+a*(sum[i]-sum[q[l]])*(sum[i]-sum[q[l]])+b*(sum[i]-sum[q[l]])+c;
        while(l<r&&yval(q[r-1],q[r])*xval(q[r],i)<=xval(q[r-1],q[r])*yval(q[r],i))r--;
        q[++r]=i;
    }
    printf("%lld
",f[n]);
    return 0;
}

例题5:P4360 [CEOI2004]锯木厂选址

网址:https://www.luogu.com.cn/problem/P4360

这是我斜率DP第一个没有一遍AC的,原因是第一遍忘开long long了。

这一题比较特殊,细心的同学一定发现了,递推式不带f。

为了方便,设d数组的后缀和为sd[i]=sd[i+1]+d[i],设k数组的前缀和为sk[i]=sk[i-1]+k[i](k[i]即是题目中的w[i])

设f[i]为第二个锯木厂选在i时的最小值,假设第一个锯木厂在j,从1~j-1运到j的和是k[p]*(sd[p]-sd[j]),p∈[1,n],从j+1~i-1运到i的和是k[p]*(sd[p]-sd[i]),p∈[j+1,i]。

从i+1~n运到第三个锯木厂的和是k[p]*sd[p],p∈[i+1,n]。设k[p]*sd[p]的和为sum。

那么整理一下此式为:f[i]=sum-sd[j]*sk[j]-sd[i]*sd[k]+sd[i]*sk[j],惊奇的发现不需要递推,可惜没有什么用。

整理成一次函数式:sd[j]*sk[j]=sd[i]*sk[j]+sum-f[i]-sd[i]*sk[i]

要让截距最大,且斜率sd[i]单调递减,那么考虑维护上凸包(不明白的一定要自己画图尝试!!)

看代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=50000;
int n,d[maxn],sd[maxn],sum,sk[maxn],k[maxn];
int q[maxn],f[maxn];
int yval(int a,int b){return sd[b]*sk[b]-sd[a]*sk[a];}
int xval(int a,int b){return sk[b]-sk[a];}
signed main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        scanf("%lld%lld",&k[i],&d[i]);
        sk[i]=sk[i-1]+k[i];
    }
    for(int i=n;i>=1;i--)
        sd[i]=sd[i+1]+d[i];
    for(int i=1;i<=n;i++)
        sum+=k[i]*sd[i];
    int l=1,r=1,ans=2147483647;
    for(int i=1;i<=n;i++){
        while(l<r&&yval(q[l],q[l+1])>=xval(q[l],q[l+1])*sd[i])l++;
        f[i]=sd[i]*sk[q[l]]-sd[q[l]]*sk[q[l]]+sum-sd[i]*sk[i];
        while(l<r&&yval(q[r-1],q[r])*xval(q[r],i)<=xval(q[r-1],q[r])*yval(q[r],i))r--;
        q[++r]=i;
        ans=min(ans,f[i]);
    }
    printf("%lld
",ans);
    return 0;
}

例题6:P4072 [SDOI2016]征途

网址:https://www.luogu.com.cn/problem/P4072

虽然又是一遍AC的,但不得不说这题打得我好慌张,还调了十几分钟,虽然都是智障错误。

回归正题。猛然一看,n只有3000,貌似不要斜率优化耶。

但事实上,普通DP是N^2*M的,所以还是得斜率优化哈哈。

构造数列A1~Am,表示第i段的和。

方差乘m^2后长这样子:m*Ai^2-(Ai)^2

惊喜的发现(Ai)^2是个定值,那么我们只要使Ai^2最小即可。

就问你眼不眼熟?这不就是例一吗?但是此题多一维。

可以推导出状转方程:f[i][j]=min{f[k][j-1]+(sum[i]-sum[k])^2}

发现f[……][i]只跟f[……][i-1]有关,那就队列里放i-1的然后统计i的答案呗。

看代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=3000+10;
int n,m,f[maxn][maxn],sum[maxn],d[maxn],q[maxn];
int yval(int a,int b,int c){
    return f[b][c]+sum[b]*sum[b]-f[a][c]-sum[a]*sum[a];
}
int xval(int a,int b){return sum[b]-sum[a];}
signed main(){
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        scanf("%lld",&d[i]);
        sum[i]=sum[i-1]+d[i];
    }
    //memset(f,0x3f,sizeof(f));
    for(int i=1;i<=n;i++)
        f[i][1]=sum[i]*sum[i];
    for(int i=2;i<=m;i++){
        int l=1,r=1;
        q[l]=0;
        for(int j=1;j<=n;j++){
            while(l<r&&yval(q[l],q[l+1],i-1)<=xval(q[l],q[l+1])*2*sum[j])l++;
            f[j][i]=f[q[l]][i-1]+(sum[j]-sum[q[l]])*(sum[j]-sum[q[l]]);
            //printf("%d %d %d %d %d
",j,i,q[l],f[q[l]][i-1],f[j][i]);
            while(l<r&&yval(q[r-1],q[r],i-1)*xval(q[r],j)>=xval(q[r-1],q[r])*yval(q[r],j,i-1))r--;
            q[++r]=j;
            //printf("%d %d
",l,r);
        }
    }
    //for(int i=1;i<=m;i++)
    //    for(int j=1;j<=n;j++)
    //        printf("f[%d][%d]=%d
",j,i,f[j][i]);
    printf("%lld
",m*f[n][m]-sum[n]*sum[n]);    
    return 0;
} 

看到那一堆调试代码了吗?警醒读者:别把i和j弄反了!!!

例题7:P5785 [SDOI2012]任务安排

网址:https://www.luogu.com.cn/problem/P5785

这道题不太一样了。通过费用提前可以推倒递推式长这样:f[i]=min{f[j]+sumt[i]*(sumc[i]-sumc[j])+s*(sumc[n]-sumc[j])}

写成一次函数形式长这样:f[j]=(s+sumt[i])*sumc[j]+f[i]-sumt[i]*sumc[i]-s*sumc[n]

但是我们发现由于t不在保证是正数,sumt[i]也没有单调性,那么就只好二分求答案了。

看代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=500000;
int n,s,t[maxn],c[maxn],f[maxn];
int st[maxn],sc[maxn],q[maxn];
int yval(int a,int b){return f[b]-f[a];}
int xval(int a,int b){return sc[b]-sc[a];}
signed main(){
    cin>>n>>s;
    for(int i=1;i<=n;i++){
        scanf("%lld%lld",&t[i],&c[i]);
        st[i]=st[i-1]+t[i];
        sc[i]=sc[i-1]+c[i];
    }
    memset(f,0x3f,sizeof(f));
    int l=1,r=1;
    q[l]=0;f[0]=0;
    for(int i=1;i<=n;i++){
        int x=1,y=r;
        if(l!=r)
            while(x<y){
                int mid=(x+y)>>1;
                if(yval(q[mid],q[mid+1])<=(s+st[i])*xval(q[mid],q[mid+1]))x=mid+1;
                else y=mid;
            }
        f[i]=f[q[x]]+st[i]*(sc[i]-sc[q[x]])+s*(sc[n]-sc[q[x]]);
        while(l<r&&yval(q[r-1],q[r])*xval(q[r],i)>=xval(q[r-1],q[r])*yval(q[r],i))r--;
        q[++r]=i;
    }
    printf("%lld
",f[n]);
    return 0;
}

 例题8:P2900 [USACO08MAR]土地征用Land Acquisition

网址:https://www.luogu.com.cn/problem/P2900

这题有些不一样,睁大眼睛看题,发现不用连续地取,那么我们就可以预处理一下。

把h从大到小排个序,然后从前往后扫一遍,如果当前的这片土地的w值不比前面的最大值大,那么他就可以被包含,无贡献。

这时我们取出了一个h递减,w递增的数列,这时取就必须连续了。

递推式长这样:f[i]=f[j]+b[i].w*b[j+1].h

写成一次函数式:f[j]=-b[i].w*b[j+1].h+f[i]

由于斜率递减,维护上凸包。

看代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=1e5;
struct data{
    int h,w;
}a[maxn],b[maxn];
int cmp(data x,data y){
    if(x.h==y.h)return x.w>y.w;
    return x.h>y.h;
}
int n,tot,mxw,q[maxn],f[maxn];
signed main(){
    cin>>n;
    for(int i=1;i<=n;i++)
        scanf("%lld%lld",&a[i].h,&a[i].w);
    sort(a+1,a+1+n,cmp);
    for(int i=1;i<=n;i++)
        if(a[i].w>mxw){
            mxw=a[i].w;
            b[++tot]=a[i];
        }
    int l=1,r=1;
    q[l]=0;
    for(int i=1;i<=tot;i++){
        while(l<r&&(f[q[l]]-f[q[l+1]])>=-b[i].w*(b[q[l]+1].h-b[q[l+1]+1].h))l++;
        f[i]=f[q[l]]+b[i].w*b[q[l]+1].h;
        while(l<r&&(f[q[r-1]]-f[q[r]])*(b[q[r]+1].h-b[i+1].h)<=(b[q[r-1]+1].h-b[q[r]+1].h)*(f[q[r]]-f[i]))r--;
        q[++r]=i;
    }
    printf("%lld
",f[tot]);
    return 0;
}
原文地址:https://www.cnblogs.com/syzf2222/p/12285944.html