【LOJ】#2035. 「SDOI2016」征途

题解

有人管它叫带权二分,有人管它叫dp凸优化,有人管它叫wqs二分……

延伸出来还有zgl分治,xjp¥!%#!@#¥!#

当我没说

我们拆个式子,很容易发现所求的就是
(msum_{i = 1}^{m}s_{i}^2 - sum^{2})

然后去掉常数我们只要求(sum_{i = 1}^{m}s_{i}^2)的最小值

然而,我们需要m个?

我们发现,这个东西随着选的个数增多,越来越少,并且少得越来越慢(斜率变大,斜率是负的!)

我们二分最后一次的斜率,选一次减少q,最后能取到的最小值用了几天,如果恰好等于m那么就是我们需要的,这个dp显然就是用斜率优化一下就OK了

当然,我们很可能出现,m = 4,前一个二分到最小值天数为3,再 + 1就变成最小值天数为5了= =,事实上二分的边界再多一个左右斜率相等就行了

代码

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
//#define ivorysi
#define pb push_back
#define space putchar(' ')
#define enter putchar('
')
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define mo 974711
#define MAXN 3005
#define RG register
using namespace std;
typedef long long int64;
typedef double db;
template<class T>
void read(T &res) {
    res = 0;char c = getchar();T f = 1;
	    while(c < '0' || c > '9') {
			if(c == '-') f = -1;
			c = getchar();
	    }
	    while(c >= '0' && c <= '9') {
		res = res * 10 + c - '0';
		c = getchar();
    }
    res *= f;
}
template<class T>
void out(T x) {
    if(x < 0) {putchar('-');x = -x;}
    if(x >= 10) {
		out(x / 10);
    }
    putchar('0' + x % 10);
}
int N,M,que[MAXN],ql,qr,cnt[MAXN];
int64 S[MAXN],dp[MAXN],V,X[MAXN],Y[MAXN];
int calc(int a,int b) {
    return dp[a] + (S[b] - S[a]) * (S[b] - S[a]) - V; 
}
bool slope(int a,int b,int c) {
    return (X[c] - X[a]) * (Y[b] - Y[a]) - (X[b] - X[a]) * (Y[c] - Y[a]) >= 0;
}
int Check(int64 mid) {
    V = mid;
    ql = 1,qr = 1;
    cnt[0] = 0;dp[0] = 0;
    que[1] = 0;
    for(int i = 1 ; i <= N ; ++i) {
		while(ql < qr) {
		    if(calc(que[ql],i) > calc(que[ql + 1],i)) ++ql;
		    else break;
		}
		dp[i] = calc(que[ql],i);cnt[i] = cnt[que[ql]] + 1;
		X[i] = S[i];Y[i] = dp[i] + S[i] * S[i];
		while(ql < qr) {
		    if(slope(que[qr - 1],que[qr],i)) --qr;
		    else break;
		}
		que[++qr] = i;
    }
    return cnt[N];
}
void Solve() {
    read(N);read(M);
    for(int i = 1 ; i <= N ; ++i) read(S[i]);
    for(int i = 1 ; i <= N ; ++i) S[i] += S[i - 1];
    int64 L = -1000000000,R = 1000000000;
    while(L <= R) {
		int64 MID = (L + R + 1) >> 1;
		int x = Check(MID);
		if(x == M || L == R) {
		    int64 ans = M * (dp[N] + MID * M) - S[N] * S[N];
		    out(ans);enter;return;
		}
		if(x > M) R = MID - 1;
		else L = MID;
    }
}
int main() {
#ifdef ivorysi
    freopen("f1.in","r",stdin);
#endif
    Solve();
    return 0;
}
原文地址:https://www.cnblogs.com/ivorysi/p/9107597.html