算法笔记--斜率优化dp

斜率优化是单调队列优化的推广

用单调队列维护递增的斜率

参考:https://www.cnblogs.com/ka200812/archive/2012/08/03/2621345.html

以例1举例说明:

转移方程为:dp[i] = min(dp[j] + (sum[i] - sum[j])^2 + C)

假设k < j < i, 如果从j转移过来比从k转移过来更优

那么 dp[j] + (sum[i] - sum[j])^2 + C < dp[k] + (sum[i] - sum[k])^2 + C

dp[j] - dp[k] < (sum[i] - sum[k])^2 - (sum[i] - sum[j])^2

dp[j] - dp[k] < -2*sum[i]*sum[k] + sum[k]*sum[k] + 2*sum[i]*sum[j] - sum[j]*sum[j]

dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k] < 2*sum[i]*(sum[j] - sum[k])

(dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k]) < 2*sum[i]    (这里假设sum[j] > sum[k], 两边同除以一个正数, 不等号方向不变)

我们观察不等式左边, 它是个斜率的形式, 自变量x为sum, 函数f(x)为dp + sum*sum

我们记这个斜率为g[j, k] = (dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k])

说明1.如果g[j, k] < 2*sum[i] 表示对于dp[i], 从j转移过来比k更优, 反之k更优

说明2.下面我们来考虑怎么从解集中去掉多余的元素: 可以证明, 存在某些元素无论对之后的哪个状态都不会是最优转移点, 可以把这些多余的元素去掉

假设k < j < i

结论:如果g[i, j] < g[j, k], 那么j可以去掉

证明:设t为之后任意一个要计算的状态(t > i). 如果g[i, j] < 2*sum[t], 那么对于dp[t], 从i转移比从j转移更优, 结论成立;

如果g[i, j] >= 2*sum[t], 那么g[j, k] > g[i, j] >= 2*sum[t], 那么对于dp[t], 从k转移比从j转移更优, 结论成立.

证毕.

所以如果把所有g[i, j] < g[j, k]的情况中(后面斜率比前面斜率小的情况)的j都去掉, 那么我们就得到相邻两个元素的斜率递增的状况

如下图(原文配图未保留): 保留下来的点(sum[j], dp[j] + sum[j]*sum[j])构成一个下凸壳, 相邻两点之间的斜率单调递增

下面来说明怎么维护这个解集:

用双端队列维护这个解集, 每次从后面加入元素时, 按照说明2的方式去掉多余元素, 使得相邻元素之间构成的斜率保持单调

每次从队首找答案: 由于相邻元素的斜率单调递增, 斜率小于2*sum[i]的最后一段的右端点就是最优解. 因为这个位置之前的斜率都小于2*sum[i],

表示后面的比前面更优; 之后的斜率都大于等于2*sum[i], 表示前面的不比后面差, 所以这个点是极值点

又因为sum[i]也具有单调性, 所以下一个极值点的位置肯定大于等于当前极值点, 所以当前极值点之前的都可以从双端队列中移出

ps:所有说明中, k < j < i

例题1:HDU - 3507

思路:维护递增斜率g[i, j] = (dp[i] - dp[j] + sum[i]*sum[i] - sum[j]*sum[j]) / (sum[i] - sum[j]) 

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the HDU 3507 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 5e5 + 10;
int a[N], n, m;   // a: input values (1-based); m: constant added per segment
LL sum[N], dp[N]; // sum: prefix sums of a; dp[i]: best cost for the first i values
bool g(int k, int j, LL C) {
    return (dp[j]-dp[k]+sum[j]*sum[j]-sum[k]*sum[k]) <= C*(sum[j]-sum[k]);
}
bool gg(int k, int j, int i) {
    return (dp[i]-dp[j]+sum[i]*sum[i]-sum[j]*sum[j])*(sum[j]-sum[k]) <= (dp[j]-dp[k]+sum[j]*sum[j]-sum[k]*sum[k])*(sum[i]-sum[j]);
}
deque<int> q; // candidate split points; adjacent slopes kept increasing

// HDU 3507: for each test case reads n, m and n values, then prints dp[n]
// where dp[i] = min_j dp[j] + (sum[i]-sum[j])^2 + m.
int main() {
    while (scanf("%d %d", &n, &m) == 2) {
        for (int i = 1; i <= n; ++i) scanf("%d", &a[i]), sum[i] = sum[i-1]+a[i];
        q.clear();
        q.push_back(0);
        for (int i = 1; i <= n; ++i) {
            // The query slope 2*sum[i] never decreases (assumes a[i] >= 0),
            // so a front candidate beaten once is beaten forever.
            while (q.size() >= 2 && g(q[0], q[1], 2*sum[i])) q.pop_front();
            const int j = q.front();
            dp[i] = dp[j] + (sum[i]-sum[j])*(sum[i]-sum[j])+m;
            // Drop back candidates that would break the increasing slopes.
            while (q.size() >= 2 && gg(q[q.size()-2], q.back(), i)) q.pop_back();
            q.push_back(i);
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
View Code

例题2:HDU - 1300

思路:维护递增斜率g[i, j] = (dp[i] - dp[j]) / (sum[i] - sum[j]) 

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the HDU 1300 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 100 + 10;
int a[N], p[N], n, m, T; // a: prefix-summed quantity; p: per-unit multiplier
LL sum[N], dp[N];        // sum: prefix sums of a; dp[i]: best cost for the first i items
bool g(int k, int j, LL C) {
    return (dp[j]-dp[k]) <= C*(sum[j]-sum[k]);
}
bool gg(int k, int j, int i) {
    return (dp[i]-dp[j])*(sum[j]-sum[k]) <= (dp[j]-dp[k])*(sum[i]-sum[j]);
}
deque<int> q; // candidate split points; adjacent slopes kept increasing

// HDU 1300: for each of T cases reads n (a, p) pairs and prints dp[n] where
// dp[i] = min_j dp[j] + (sum[i]-sum[j]+10)*p[i].
int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; ++i) scanf("%d %d", &a[i], &p[i]), sum[i] = sum[i-1]+a[i];
        // Suffix minimum makes p non-decreasing, so the query slope only grows.
        for (int i = n-1; i >= 1; --i) p[i] = min(p[i], p[i+1]);
        q.clear();
        q.push_back(0);
        for (int i = 1; i <= n; ++i) {
            while (q.size() >= 2 && g(q[0], q[1], p[i])) q.pop_front();
            const int j = q.front();
            dp[i] = dp[j] + (sum[i]-sum[j]+10)*p[i];
            while (q.size() >= 2 && gg(q[q.size()-2], q.back(), i)) q.pop_back();
            q.push_back(i);
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
View Code

例题3:HDU - 2993

思路:论文题,维护递增的斜率,居然卡读入,没意思

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state and raw input buffer for the HDU 2993 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 1e5 + 10;
int n, k, a[N], q[N], head, tail; // q[head..tail): array-based deque of candidates
double sum[N];                    // prefix sums of a
const int BUF = 25000000;
char Buf[BUF],*buf=Buf;           // whole-input buffer and the parse cursor
// Parses the next run of bytes above '/' from the global buffer as a decimal
// integer, first skipping every byte below '0'. Advances the cursor buf.
inline void read(int &a)
{
    while (*buf < '0') ++buf;
    int value = 0;
    while (*buf > '/') value = value * 10 + (*buf++ - '0');
    a = value;
}
int main() {
    int tot = fread(Buf, 1, BUF, stdin);
    while(true) {
        if(buf-Buf+1 >= tot) break;
        read(n), read(k);
        for (int i = 1; i <= n; ++i) read(a[i]), sum[i] = sum[i-1]+a[i];
        head = tail = 0;
        q[tail++] = 0;
        double ans = 0;
        for (int i = k; i <= n; ++i) {
            while(head+1 < tail) {
                int a = q[head];
                head++;
                int b = q[head];
                if((sum[i]-sum[a])*(i-b) < (sum[i]-sum[b])*(i-a)) ;
                else {
                    q[--head] = a;
                    break;
                }
            }
            int x = q[head];
            ans = max(ans, (sum[i]-sum[x])/(i-x));
            x = i-k+1;
            while(head+1 < tail) {
                int b = q[tail-1];
                --tail;
                int a = q[tail-1];
                if((sum[x]-sum[b])*(x-a) < (sum[x]-sum[a])*(x-b));
                else {
                    q[tail++] = b;
                    break;
                }
            }
            q[tail++] = x;
        }
        printf("%.2f
", ans);
    }
    return 0;
}
View Code

例题4:UVALive - 5097

思路:去掉被其他矩形完全覆盖(宽和高都不超过它)的矩形后, 按宽度递增排序, 高度就是严格递减的

那么从j转移到i的代价为w[i]*h[j+1], 维护递增斜率:g[j, k] = (dp[j] - dp[k]) / (h[k+1] - h[j+1])

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the UVALive 5097 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 5e4 + 10;
pii a[N];              // raw (width, height) pairs
vector<pii> vc;        // non-dominated pairs
int n, k, h[N], w[N];  // w/h: filtered widths and heights, 1-based
LL dp[105][N];         // dp[g][i]: best cost for the first i pieces in g groups
deque<int> q[105];     // q[g]: hull of split points for layer g
bool g(int id, int k, int j, LL C) {
    return (dp[id][j]-dp[id][k]) <= C*(h[k+1]-h[j+1]);
}
bool gg(int id, int k, int j, int i) {
    return (dp[id][i]-dp[id][j])*(h[k+1]-h[j+1]) <= (dp[id][j]-dp[id][k])*(h[j+1]-h[i+1]);
}
// UVALive 5097: for each case reads n (width, height) pairs and k, and prints
// the minimum over at most k groups of the sum of (max width)*(max height).
int main() {
    while (scanf("%d %d", &n, &k) == 2) {
        for (int i = 1; i <= n; ++i) scanf("%d %d", &a[i].fi, &a[i].se);
        // Sort by width, then keep (scanning from the widest) only pairs
        // taller than every wider pair already kept; the rest are covered.
        sort(a+1, a+1+n);
        vc.clear();
        for (int i = n; i >= 1; --i) if (i == n || a[i].se > vc.back().se) vc.pb(a[i]);
        reverse(vc.begin(), vc.end());
        n = vc.size();
        for (int i = 0; i < n; ++i) w[i+1] = vc[i].fi, h[i+1] = vc[i].se;
        for (int i = 0; i <= k; ++i) q[i].clear();
        for (int i = 0; i <= k; ++i) for (int j = 0; j <= n; ++j) dp[i][j] = 0x3f3f3f3f3f3f3f3f;
        dp[0][0] = 0;
        q[0].push_back(0);
        for (int i = 1; i <= n; ++i) {
            // dp[j+1][i] is reachable only for j < i, and exactly those layers
            // have a non-empty queue (front() on an empty deque is UB).
            for (int j = 0; j < k && j < i; ++j) {
                while (q[j].size() >= 2 && g(j, q[j][0], q[j][1], w[i])) q[j].pop_front();
                const int x = q[j].front();
                dp[j+1][i] = dp[j][x] + w[i]*1LL*h[x+1];
            }
            // Split point n is never queried (and gg would read h[n+1]).
            if (i == n) continue;
            // Offer i as a split point only where dp[j][i] is finite
            // (1 <= j <= i); INF entries would overflow gg's products.
            for (int j = 1; j < k && j <= i; ++j) {
                while (q[j].size() >= 2 && gg(j, q[j][q[j].size()-2], q[j].back(), i)) q[j].pop_back();
                q[j].push_back(i);
            }
        }
        LL ans = 1LL<<60;
        for (int i = 1; i <= k; ++i) ans = min(ans, dp[i][n]);
        printf("%lld\n", ans);
    }
    return 0;
}
View Code

例题5:HDU - 3045

思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k) / (a[j+1]-a[k+1])

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the HDU 3045 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 4e5 + 5;
int n, k;               // k: minimum group size
LL a[N], sum[N], dp[N]; // a: sorted values; sum: prefix sums; dp[i]: best cost for the first i
bool g(int k, int j, LL C) {
    return dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k <= C*(a[j+1]-a[k+1]);
}
bool gg(int k, int j, int i) {
    return (dp[i]-dp[j]+sum[j]-sum[i]+a[i+1]*i-a[j+1]*j)*(a[j+1]-a[k+1]) <= (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k)*(a[i+1]-a[j+1]);
}
deque<int> q; // candidate split points; adjacent slopes kept increasing

// HDU 3045: for each case reads n, k and n values; prints dp[n] where each
// group (j, i] has size >= k and costs sum[i]-sum[j]-a[j+1]*(i-j).
int main() {
    while (scanf("%d %d", &n, &k) == 2) {
        for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
        sort(a+1, a+1+n);
        for (int i = 1; i <= n; ++i) sum[i] = sum[i-1]+a[i];
        q.clear();
        dp[0] = 0;
        q.push_back(0);
        for (int i = k; i <= n; ++i) {
            while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
            const int j = q.front();
            dp[i] = dp[j]+sum[i]-sum[j]-a[j+1]*1LL*(i-j);
            // Split i-k+1 becomes usable for i+1 (group size exactly k); it is
            // a valid state itself only once i-k+1 >= k.
            const int x = i-k+1;
            if (x >= k) {
                while (q.size() >= 2 && gg(q[q.size()-2], q.back(), x)) q.pop_back();
                q.push_back(x);
            }
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
View Code

例题6:POJ - 1180

思路:启动时间s要单独计算, 因为有s存在时, 完成时间不能直接用前缀和表示. 每当从第j+1个任务开始新的一批时, s会使第j+1..n个任务的完成时间都推迟s, 贡献为s*suf[j+1](suf为F的后缀和)

那么就是维护递增斜率:g[j, k] = (dp[j]-dp[k]+s*(suf[j+1]-suf[k+1])) / (sum[j] - sum[k])

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the POJ 1180 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<deque>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 1e4 + 5;
int T[N], F[N], n, s;      // T: times (later prefix-summed); F: factors; s: start-up time
LL sum[N], suf[N], dp[N];  // sum/suf: prefix/suffix sums of F; dp[i]: best cost for the first i
bool g(int k, int j, LL C) {
    return dp[j]-dp[k]+s*(suf[j+1]-suf[k+1]) <= C*(sum[j]-sum[k]);
}
bool gg(int k, int j, int i) {
    return (dp[i]-dp[j]+s*(suf[i+1]-suf[j+1]))*(sum[j]-sum[k]) <= (dp[j]-dp[k]+s*(suf[j+1]-suf[k+1]))*(sum[i]-sum[j]);
}
deque<int> q; // candidate split points; adjacent slopes kept increasing

// POJ 1180: reads n, s and n (T, F) pairs; prints dp[n] where
// dp[i] = min_j dp[j] + T[i]*(sum[i]-sum[j]) + s*suf[j+1] (T prefix-summed).
int main() {
    scanf("%d", &n);
    scanf("%d", &s);
    for (int i = 1; i <= n; ++i) scanf("%d %d", &T[i], &F[i]);
    for (int i = 1; i <= n; ++i) sum[i] = sum[i-1] + F[i], T[i]+=T[i-1];
    for (int i = n; i >= 1; --i) suf[i] = suf[i+1] + F[i];
    q.push_back(0);
    for (int i = 1; i <= n; ++i) {
        while (q.size() >= 2 && g(q[0], q[1], T[i])) q.pop_front();
        const int j = q.front();
        dp[i] = dp[j] + T[i]*(sum[i]-sum[j])+s*suf[j+1];
        while (q.size() >= 2 && gg(q[q.size()-2], q.back(), i)) q.pop_back();
        q.push_back(i);
    }
    printf("%lld\n", dp[n]);
    return 0;
}
View Code

例题7:POJ - 2018

思路:同HDU-2993

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the POJ 2018 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<deque>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 1e5 + 10;
int n, f, a[N];  // f: minimum segment length
LL sum[N];       // prefix sums of a
deque<int> q;    // lower hull of points (x, sum[x])
bool g(int k, int j, int i) {
    return (sum[j]-sum[k])*(i-j) <= (sum[i]-sum[j])*(j-k);
}
// POJ 2018: reads n, f and n values; prints floor(1000 * max average) over
// segments of length at least f.
int main() {
    scanf("%d %d", &n, &f);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]), sum[i]=sum[i-1]+a[i];
    q.push_back(0);
    LL ans = 0;
    for (int i = f; i <= n; ++i) {
        // Advance the front while the next point gives a steeper line to i.
        while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
        int x = q.front();
        ans = max(ans, (sum[i]-sum[x])*1000/(i-x));
        // Prefix index i+1-f becomes a left boundary for segments ending at i+1.
        x = i+1-f;
        while (q.size() >= 2 && !g(q[q.size()-2], q.back(), x)) q.pop_back();
        q.push_back(x);
    }
    printf("%lld\n", ans);
    return 0;
}
View Code

例题8:POJ - 3709

思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k) / (a[j+1]-a[k+1])

代码:

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the POJ 3709 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<deque>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head 

const int N = 5e5 + 10;
int a[N], n, k, T;  // k: minimum group size; T: number of test cases
LL sum[N], dp[N];   // sum: prefix sums of a; dp[i]: best cost for the first i
LL dw(int k, int j) {
    return a[j+1]-a[k+1];
}
LL up(int k, int j) {
    return dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*1LL*j-a[k+1]*1LL*k;
}
LL g(int k, int j, LL C) {
    return up(k, j) <= C*dw(k, j);
}
LL gg(int k, int j, int i) {
    return up(j, i)*dw(k, j) <= up(k, j)*dw(j, i);
}
deque<int> q; // candidate split points; adjacent slopes kept increasing

// POJ 3709: for each of T cases reads n, k and n values; prints dp[n] where
// each group (j, i] has size >= k and costs sum[i]-sum[j]-a[j+1]*(i-j).
int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d %d", &n, &k);
        for (int i = 1; i <= n; ++i) scanf("%d", &a[i]), sum[i]=sum[i-1]+a[i];
        q.clear();
        q.push_back(0);
        for (int i = k; i <= n; ++i) {
            while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
            int x = q.front();
            dp[i] = dp[x]+sum[i]-sum[x]-a[x+1]*1LL*(i-x);
            // Split i-k+1 becomes usable for i+1; it is a valid state itself
            // only once i-k+1 >= k.
            x = i-k+1;
            if (x >= k) {
                while (q.size() >= 2 && gg(q[q.size()-2], q.back(), x)) q.pop_back();
                q.push_back(x);
            }
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
View Code

例题9:UVA - 12594

思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]-k*s[k]+j*s[j]) / (j-k), 其中pos[p]为第p个字母在给定字母顺序中的下标(从0开始), sum[i] = ∑_{p=1}^{i} (p-1-pos[p])*pos[p], s[i] = ∑_{p=1}^{i} pos[p]

// Common contest template (pragmas, headers, shorthand macros) followed by
// the global state for the UVA 12594 solution.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
// Presumably renames y1 to dodge the y1() Bessel function from <cmath>.
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
// Prints "name = value" to stderr. The newline must be the escape "\n";
// a literal line break inside the string leaves it unterminated.
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head

const int N = 2e4 + 10, M = 505;
const LL INF = 0x3f3f3f3f3f3f3f3f;
int T, n, k, pos[26];       // pos[c]: index of letter c in the given ordering
LL sum[N], s[N], dp[M][N];  // prefix sums used by the cost; dp[g][i]: g groups over i chars
char nm[N], pn[N];          // nm: the string (1-based); pn: the 26-letter ordering
deque<int> q[M];            // q[g]: hull of split points for layer g
LL up(int id, int k, int j) {
    return dp[id][j]-dp[id][k]+sum[k]-sum[j]-k*s[k]+j*s[j];
}
LL dw(int k, int j) {
    return j-k;
}
bool g(int id, int k, int j, LL C) {
    return up(id, k, j) <= C*dw(k, j);
}
bool gg(int id, int k, int j, int i) {
    return up(id, j, i)*dw(k, j) <= up(id, k, j)*dw(j, i);
}
// UVA 12594: for each case reads a 26-letter ordering pn, the group count k
// and a string nm; prints dp[k][n], where dp[j][i] splits the first i
// characters into j non-empty consecutive groups, group (x, i] costing
// sum[i]-sum[x]-x*(s[i]-s[x]).
int main() {
    scanf("%d", &T);
    for (int cs = 1; cs <= T; ++cs) {
        scanf("%s %d", pn, &k);
        scanf("%s", nm+1);
        n = strlen(nm+1);
        for (int i = 0; i < 26; ++i) pos[pn[i]-'a'] = i;
        for (int i = 1; i <= n; ++i) s[i] = s[i-1]+pos[nm[i]-'a'];
        for (int i = 1; i <= n; ++i) sum[i] = sum[i-1]+(i-1-pos[nm[i]-'a'])*1LL*pos[nm[i]-'a'];
        for (int i = 0; i <= k; ++i) q[i].clear();
        dp[0][0] = 0;
        q[0].push_back(0);
        for (int i = 1; i <= n; ++i) {
            // Only layers j < i can end a group at i; exactly those queues
            // are non-empty (front() on an empty deque is UB).
            for (int j = 0; j < k && j < i; ++j) {
                while (q[j].size() >= 2 && g(j, q[j][0], q[j][1], s[i])) q[j].pop_front();
                const int x = q[j].front();
                dp[j+1][i] = dp[j][x]+sum[i]-sum[x]-x*(s[i]-s[x]);
            }
            // Offer i as a split point only where dp[j][i] is a real state
            // (1 <= j <= i); layer k's queue is never read.
            for (int j = 1; j < k && j <= i; ++j) {
                while (q[j].size() >= 2 && gg(j, q[j][q[j].size()-2], q[j].back(), i)) q[j].pop_back();
                q[j].push_back(i);
            }
        }
        // NOTE(review): assumes 1 <= k <= n; for k > n dp[k][n] is not
        // computed in this case — confirm against the problem limits.
        printf("Case %d: %lld\n", cs, dp[k][n]);
    }
    return 0;
}
View Code

 

原文地址:https://www.cnblogs.com/widsom/p/9323394.html