bzoj4518 [Sdoi2016]征途(斜率优化dp)

分析:
斜率优化dp
很多人做斜率优化的时候喜欢画出斜率
我偏向画柿子

题目就是把若干个元素分成m份,
每一份的价值是该组中的元素之和
使得m组数的方差最小

平均数:x=(sum[1]+sum[2]+..+sum[m])/m //sum是每一组的价值
x=(Σa[i])/m
方差:s=((sum[1]-x)^2+(sum[2]-x)^2+…+(sum[m]-x)^2)/m

先想状态转移方程:
设 f[i][j]表示到第i个点,分成了j段
f[i][k]=f[j][k-1]+(sum[i]-sum[j]-x)^2
sum是前缀和

这样的话就简单了,这就是一个斜率优化的模板

我们已经说过了:
x=(Σa[i])/m
s=((sum[1]-x)^2+(sum[2]-x)^2+…+(sum[m]-x)^2)/m
最后答案是s*m^2

带入s的表达式得:
这里写图片描述

ans=Σ(sum[i]-sum[j])^2*m-sum[n]^2

只要把状态转移方程f[i][k]=f[j][k-1]+(sum[i]-sum[j]-x)^2
变成f[i][k]=f[j][k-1]+(sum[i]-sum[j])^2 即可

之后就是画柿子

斜率优化

我们假设k < j < i
如果j的决策比k的决策要好
则有
这里写图片描述

左边那一大坨是一个斜率的形式,
我们可以用ta来优化了

设g[j][k]=(那一大坨式子)

g[j][k] < sum[i]
表示j比k更优

现在关键来了
设k < j < i
如果g[i][j] < g[j][k],那么j点永远不可能成为最优转移点

解释一下
我们假设g[i][j] < sum[i],那么也就是说i要比j优,排除j

若g[i][j]>=sum[i] 也就是说j比i要优
但是g[i][k]>g[i][j],说明k比j还要优

接下来看看怎么找最优解
设k < j < i
我们排除了g[i][j] < g[j][k]
则整个有效点集呈现一种上凸性质,即k <=> j的斜率要大于j <=> i的斜率

做法可以总结如下:
1.用一个单调队列来维护解集
2.假设队列中从头到尾已经有元素a,b,c
那么当d要入队的时候,我们维护队列的上凸性质,
即如果g[d][c] < g[c][b],那么就将c点删除
直到找到g[d][x]>=g[x][y]为止,并将d点加入在该位置中
3.求解时候,从队头开始,如果已有元素a,b,c,
当i点要求解时,如果g[b][a] < sum[i],那么说明b点比a点更优,a点可以排除,于是a出队
直到g[x][y]>=sum[i]
当前点就从y转移

最终答案:

min(f[n])*m-sum[n]^2

tip

我为什么要做斜率优化!!!
我讨厌式子

这是一个二维的方程,所以我们需要另一个数组记录上一层的状态,辅助dp
在计算斜率的转移的时候我们都要用上一层的状态

我和学姐对式子的处理方式不一样,
我超虚的,然而我一A了!!!

这里写代码片
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#define ll long long

using namespace std;

const ll INF=1e16;
int n,m; 
ll a[3010];
ll f[3010],g[3010],q[3010],tou,wei,x=0;

ll sqr(ll x)
{
    return x*x;
}

double get(int j,int k)
{
    return (double)(g[j]+sqr(a[j])-g[k]-sqr(a[k]))/(double)(2*(a[j]-a[k]));
}

void doit()
{
    int i,j;
    ll ans=INF;
    tou=wei=1;
    for (int i=1;i<=n;i++) g[i]=INF;   //记录上一层状态 
    g[0]=0;
    for (i=1;i<=m;i++)
    {
        tou=wei=0;
        for (j=1;j<=n;j++)
        {
            while (tou<wei&&get(q[tou+1],q[tou])<a[j]) tou++;
            f[j]=g[q[tou]]+sqr(a[j]-a[q[tou]]);
            while (tou<wei&&get(j,q[wei])<get(q[wei],q[wei-1])) wei--;
            q[++wei]=j;
        }
        ans=min(ans,f[n]);
        for (int j=1;j<=n;j++) g[j]=f[j];
    }
    printf("%lld",m*ans-sqr(a[n]));
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]),a[i]+=a[i-1];
    doit();
    return 0;
}
原文地址:https://www.cnblogs.com/wutongtong3117/p/7673160.html