石子合并(四边形不等式优化)

题目大意很简单,和普通的石子合并过程没有区别,只是花费变成了一个多项式,若连续的任意个石子权值和为x,那么代价变为F(x) = sigma(a[i] * x^i),求将n堆石子合并为一队的最小花费。

对于暴力的做法,复杂度是O(n^3)的,所以要优化

我们知道当a, b, c, d(a <= b < c <= d)当有cost[a][c] + cost[b][d] <= cost[a][d] + cost[b][c] 时,我们称其满足四边形不等式,设p[i][j]表示当区间[i, j]取最优决策时所选择的下标,这时可以证明有p[i][j - 1] <= p[i][j] <= p[i + 1][j](花了我好长时间终于证明了),没事了可以证明下看看,也可以记住这个结论。

这时当按区间dp时,计算区间[i, j]的最优解,只要枚举[p[i][j - 1], p[i + 1][j]]即可,由于数组p取值为[1, n]且是单调的,所以枚举的总复杂度为O(n),最后加上区间枚举的复杂度,总复杂度为O(n^2)

所以对于一般性的题目,需要证明的只有dp量是不是满足四边形不等式的,对于这道题就是要证明:

设sum(a, b) = x, sum(b, c) = z, sum(c, d) = y;

有 F(x + z) + F(y + z) <= F(z) + F(x + y + z),即证明:

sigma(a[i] * ( (x + z)^i + (y + z)^i - z^i - (x+y+z)^i )) <= 0,转化为证明:

(x + z) ^ n  +  (y + z) ^ n  -  z ^ n  -  (x + y + z) ^ n <= 0恒成立。

很明显这个不等式可以利用数学归纳法加以简单的证明。

 1 #include <map>
 2 #include <set>
 3 #include <stack>
 4 #include <queue>
 5 #include <cmath>
 6 #include <ctime>
 7 #include <vector>
 8 #include <cstdio>
 9 #include <cctype>
10 #include <cstring>
11 #include <cstdlib>
12 #include <iostream>
13 #include <algorithm>
14 using namespace std;
15 #define INF 0x3f3f3f3f
16 #define inf (-((LL)1<<40))
17 #define lson k<<1, L, (L + R)>>1
18 #define rson k<<1|1,  ((L + R)>>1) + 1, R
19 #define mem0(a) memset(a,0,sizeof(a))
20 #define mem1(a) memset(a,-1,sizeof(a))
21 #define mem(a, b) memset(a, b, sizeof(a))
22 #define FIN freopen("in.txt", "r", stdin)
23 #define FOUT freopen("out.txt", "w", stdout)
24 #define rep(i, a, b) for(int i = a; i <= b; i ++)
25 
26 template<class T> T CMP_MIN(T a, T b) { return a < b; }
27 template<class T> T CMP_MAX(T a, T b) { return a > b; }
28 template<class T> T MAX(T a, T b) { return a > b ? a : b; }
29 template<class T> T MIN(T a, T b) { return a < b ? a : b; }
30 template<class T> T GCD(T a, T b) { return b ? GCD(b, a%b) : a; }
31 template<class T> T LCM(T a, T b) { return a / GCD(a,b) * b;    }
32 
33 //typedef __int64 LL;
34 typedef long long LL;
35 const int MAXN = 51000;
36 const int MAXM = 110000;
37 const double eps = 1e-4;
38 //LL MOD = 987654321;
39 
40 int T, n, m, s[1100], a[10];
41 LL p[1100][1100], dp[1100][1100];
42 
43 LL fun(int x) {
44     LL ans = 0, p = 1;
45     rep (i, 0, m) {
46         ans += a[i] * p;
47         p *= x;
48     }
49     return ans;
50 }
51 
52 int main()
53 {
54     //FIN;
55     while(~scanf("%d", &T)) while(T--) {
56         scanf("%d", &n);
57         rep (i, 1, n) scanf("%d", s + i), s[i] += s[i - 1];
58         scanf("%d", &m);
59         rep (i, 0, m) scanf("%d", a + i);
60         mem0(dp); mem0(p);
61         rep (len, 1, n) {
62             rep (i, 1, n - len + 1) {
63                 int j = i + len - 1;
64                 LL cost = fun(s[j] - s[i - 1]);
65                 if(len <= 1) { dp[i][j] = 0; p[i][j] = i; }
66                 else rep (k, p[i][j - 1], p[i + 1][j]) {
67                     if(dp[i][k] + dp[k+1][j] + cost < dp[i][j] || dp[i][j] == 0) {
68                         p[i][j] = k;
69                         dp[i][j] = dp[i][k] + dp[k+1][j] + cost;
70                     }
71                 }
72             }
73         }
74         cout << dp[1][n] << endl;
75     }
76     return 0;
77 }
原文地址:https://www.cnblogs.com/gj-Acit/p/4493512.html