Codeforces 311B Cats Transport 斜率优化dp

Cats Transport

出发时间居然可以是负数，我服了……这个坑卡了我十几次，我一直以为是斜率优化写挫了。

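设 $hs_i$ 为去重排序后第 $i$ 小的出发时间，$cnt_i$ 为出发时间不超过 $hs_i$ 的猫的数量，$sum_i$ 为这些猫的出发时间之和，$dp[i][j]$ 表示用 $j$ 个饲养员接走前 $i$ 种出发时间对应的所有猫的最小等待时间之和。我们能得出 dp 方程式：$dp[i][j] = \min_{k < i} \{\, dp[k][j-1] + hs_i \cdot (cnt_i - cnt_k) - (sum_i - sum_k) \,\}$。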

这个东西显然能斜率优化, 直接搞。

其实不用离散化直接dp更好写。

#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ull unsigned long long
using namespace std;

// Capacity of every per-index array below (indices are 1-based).
const int N = 100000 + 7;
// 32-bit "infinity" sentinel; not referenced by the code visible in this file.
const int inf = 0x3f3f3f3f;
// 64-bit "infinity": initial value for minima and fill value for unreached dp cells.
const long long INF = 0x3f3f3f3f3f3f3f3f;
// Common modulus constant; not referenced by the code visible in this file.
const int mod = 1000000007;
// Floating-point tolerance; not referenced by the code visible in this file.
const double eps = 1e-8;

// Sizes read from the first input line.
int n;
int m;
int p;
// Number of distinct values stored in hs[1..tot].
int tot;
// d[i]: input distances, turned into prefix sums by main().
int d[N];
// hs[1..tot]: sorted, de-duplicated shifted times.
int hs[N];
// h[i], t[i]: the i-th input pair; t[i] is shifted in place by main().
int h[N];
int t[N];

// dp[layer][i]: two rolling layers of the DP table.
long long dp[2][N];
// cnt[i]: number of entries whose shifted time is <= hs[i] (after prefix summing).
long long cnt[N];
// sum[i]: sum of the shifted times counted in cnt[i].
long long sum[N];

// One monotone queue per rolling layer: que[layer][head[layer]..rear[layer]].
int que[2][N];
int head[2];
int rear[2];

// Returns the 1-based position of the first element of the sorted range
// hs[1..tot] that is not less than x (tot + 1 when every element is smaller).
int getId(int x) {
    int *first = hs + 1;
    int *last = first + tot;
    int *pos = std::lower_bound(first, last, x);
    return pos - hs;
}

// Slope of the line through the hull points of indices j and k in dp layer t,
// where index x maps to the point (cnt[x], dp[t][x] + sum[x]).
// The caller must guarantee cnt[j] != cnt[k].
double calc(int j, int k, int t) {
    const long long numerator = dp[t][j] + sum[j] - dp[t][k] - sum[k];
    const long long denominator = cnt[j] - cnt[k];
    const double slope = 1.0 * numerator / denominator;
    return slope;
}

int main() {
    scanf("%d%d%d", &n, &m, &p);
    for(int i = 2; i <= n; i++) scanf("%d", &d[i]);
    for(int i = 2; i <= n; i++) d[i] += d[i - 1];
    LL mn = INF;
    for(int i = 1; i <= m; i++) {
        scanf("%d%d", &h[i], &t[i]);
        t[i] -= d[h[i]];
        mn = min(mn, 1ll*t[i]);
    }
    for(int i = 1; i <= m; i++) {
        t[i] -= mn;
        hs[++tot] = t[i];
    }
    sort(hs + 1, hs + tot + 1);
    tot = unique(hs + 1, hs + tot + 1) - hs - 1;
    for(int i = 1; i <= m; i++) cnt[getId(t[i])]++;
    for(int i = 1; i <= tot; i++) sum[i] = sum[i - 1] + cnt[i] * hs[i];
    for(int i = 1; i <= tot; i++) cnt[i] += cnt[i - 1];
    int cur = 0, lst = 1;
    for(int j = 1; j <= p; j++) {
        swap(cur, lst);
        memset(dp[cur], INF, sizeof(dp[cur]));
        head[cur] = 1;
        rear[cur] = 0;
        for(int i = j; i <= tot; i++) {
            if(j == 1) {
                dp[cur][i] = hs[i] * cnt[i] - sum[i];
            }
            else {
                while(rear[lst]-head[lst]+1 >= 2
                      && calc(que[lst][head[lst]+1], que[lst][head[lst]], lst) < hs[i]) head[lst]++;
                int who = que[lst][head[lst]];
                dp[cur][i] = dp[lst][who] + hs[i] * (cnt[i] - cnt[who]) - sum[i] + sum[who];
            }
            while(rear[cur]-head[cur]+1 >= 2 &&
                  calc(que[cur][rear[cur]], que[cur][rear[cur]-1], cur) >
                  calc(i, que[cur][rear[cur]], cur)) rear[cur]--;
            que[cur][++rear[cur]] = i;
        }
    }
    LL ret = INF;
    if(p >= tot) ret = 0;
    else ret = dp[cur][tot];
    printf("%lld
", ret);
    return 0;
}

/*
*/

  

原文地址:https://www.cnblogs.com/CJLHY/p/10409602.html