最小乘积生成树的另类做法

最小乘积生成树是最小生成树的变形,每条边有一个权值$(a_i, b_i)$, 我们要求一棵生成树,使得$sum{a_i} cdot sum{b_i}$最小。

网上大多数做法是,把解空间看做二维平面上的点,$sum{a_i}$ $sum{b_i}$分别看做点的横纵坐标。显然最优解一定是在解集构成的下凸壳上。

这里需要用到另外一种求凸壳的方法。先确定最左边的点$A$,再确定最右边的点$B$,然后找到离直线$BA$最远的点$C$也就是满足$overrightarrow{BA} imes overrightarrow{BC}$最大的点$C$,这个点一定在凸壳上,然后递归$AC$, $CB$部分的凸壳。 似乎可以证明下凸壳上的点是$O(M^2)$级别的,总的复杂度是$O(M^3logM)$的。

最近做Petrozavodsk Winter-2014. Moscow SU Tapir Contest的A题的时候学到了另外一种姿势,理论复杂度是一样的,但是实现起来更加方便,在此分享一下。

原题是给定$N$个点,每个点有一个权值$(a_i, b_i)$,选取恰好$K$个点使得$sum{a_i} cdot sum{b_i}$最小。显然这个题可以套用最小乘积生成树求下凸壳的做法,时间复杂度是$O(N^3logN)$ 。 用之后介绍的算法同样可以做到$O(N^3logN)$,而且可以通过一些手段优化做到$O(N^2logN)$。

这个做法可以扩展回最小乘积生成树,不过遗憾的是,扩展回最小乘积生成树问题不能通过本题采用的手段降低复杂度,还是只能做到$O(M^3logM)$。

参考知乎的回答

首先是一个非常重要的转化:一定存在某个常数$lambda(lambda geq 0)$,使得$a_i + lambda b_i$前K小的点集与原问题的最优解相同。(对应到最小乘积生成树问题中,即将边权按$a_i + lambda b_i$排序做kruskal算法得到的最小生成树就是最小乘积生成树, 以下内容将考虑取$K$个点的问题,请自行对应到最小乘积生成树问题)。

证明: 假设最优解的$sum{a_i} = A, sum{b_i} = B$, 取$lambda = frac{B}{A}$即可。

取$lambda = frac{B}{A}$. 则$a_i + lambda b_i$前K小的点集的$sum{a_i} = A', sum{b_i} = B'$.

则$A' + lambda B' leq  A + lambda B$

若$A' + lambda B' lt A + lambda B$

则$A' cdot lambda B' leq (frac{A' + lambda B'}{2})^2 lt (frac{A+lambda B}{2})^2 = lambda AB $    即 $ A'B' < AB$,于是产生了矛盾。

因此一定有$A' + lambda B' =  A + lambda B$.  所以最优解一定对应$sum{a_i + lambda b_i}$最小的某个解。

而我们的目的是$sum{a_i} cdot sum{b_i}$最小,因此$sum{a_i + lambda b_i}$最小还不够,还要使得$sum{a_i}$最小或者$sum{b_i}$最小.

因此,我们只要枚举所有的$lambda$,然后按$a_i + lambda b_i$为第一关键字,$a_i$为第二关键字,排序做一遍,再按$a_i + lambda b_i$为第一关键字,$b_i$为第二关键字做一遍,一定可以得到最优解。

虽然$lambda$的取值有无穷多种,但实际上我们只要考虑$a_i + lambda b_i$的排序结果不同的$lambda$。 即使得$a_i + lambda b_i = a_j + lambda b_j$的这些$lambda$。 再进一步分析,其实我们甚至不需要对每个$lambda$分别以$a_i$和$b_i$为第二关键字排序做两遍。假设使得$a_i + lambda b_i = a_j + lambda b_j$的这些$lambda$分别是$lambda_1, lambda_2, cdots, lambda_k$, 我们只需要分别在区间$(0, lambda_1), (lambda_1, lambda_2) cdots (lambda_{k-1}, lambda_k), (lambda_k, +inf)$各取一个$lambda$,仅仅按照$a_i + lambda b_i$排序做一遍即可。

因此对于本题总共只要做$O(N^2)$遍排序,总的复杂度$O(N^3logN)$;对于最小乘积生成树问题,只要做$O(M^2)$遍普通最小生成树即可。

对于本题,只要枚举$lambda$然后排序取前$K$小。如果按从小到大的顺序枚举$lambda$,每次改变的时候我们可以知道哪些元素的排名上升了,哪些下降了,因此可以直接维护,不必每次重新排序。具体只要每次将那些排名会发生变化的元素重新拿出来排序,然后插回去。可以想象成,数轴上有$N$辆车,起始位置分别是$a_i$,速度分别是$b_i$,而$lambda$可以看做时间,随着$lambda$的变大,维护车的排名。  复杂度可以降低到$O(N^2logN)$.

然而,最小乘积生成树问题并不能这样维护,因为kruskal算法不是单纯的取最小的几条边。

本题代码:

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 
  4 typedef long long LL;
  5 #define N 1010
  6 int a[N], b[N], rk[N], now[N], n, K;
  7 double val[N];
  8 LL ans = 1e18;
  9 vector<int> best;
 10 const double EPS = 1e-6;
 11 
 12 
 13 struct event
 14 {
 15     double t;
 16     int k1, k2;
 17     bool operator < (const event &o)const
 18     {
 19         return t < o.t;
 20     }
 21 };
 22 vector<event> L;
 23 
 24 
 25 bool cmp(int x, int y) {return val[x] < val[y];}
 26 
 27 int main()
 28 {
 29     //freopen("in.txt", "r", stdin);
 30  
 31     scanf("%d %d", &n, &K);
 32     for (int i = 1; i <= n; ++i)
 33         scanf("%d %d", &a[i], &b[i]), now[i] = i, val[i] = a[i];
 34     sort(now + 1, now + n + 1, cmp);
 35     for (int i = 1; i <= n; ++i) rk[now[i]] = i;
 36     
 37     
 38     LL sa = 0, sb = 0;
 39     for (int i = 1; i <= K; ++i) 
 40     {
 41         sa += a[now[i]], sb += b[now[i]];
 42          best.push_back(now[i]);
 43     }
 44     ans = sa * sb;
 45     
 46     
 47     for (int i = 1; i < n; ++i)
 48     {
 49         for (int j = i + 1; j <= n; ++j)
 50         {
 51             if (b[i] == b[j]) continue;
 52             //a[i] + k*b[i] = a[j] + k * b[j]  
 53             if (1LL * (a[j] - a[i]) * (b[i] - b[j]) >= 0) 
 54                 L.push_back((event){(double)(a[j] - a[i]) / (b[i] - b[j]), i, j});
 55         }
 56     }
 57     sort(L.begin(), L.end()); 
 58     
 59     for (int i = 0, j; i < L.size(); i = j + 1)
 60     {
 61         vector<int> lis;
 62         vector<int> pos;
 63         
 64         j = i;
 65         while (j + 1 < L.size() && fabs(L[j + 1].t - L[i].t) < EPS) ++j;
 66         for (int k = i; k <= j; ++k)
 67         {
 68             lis.push_back(L[k].k1); 
 69             lis.push_back(L[k].k2);
 70         }
 71         sort(lis.begin(), lis.end());
 72         lis.erase(unique(lis.begin(), lis.end()), lis.end());
 73                                        
 74         double lamda = j + 1 == L.size()? L[i].t + 1: (L[j].t + L[j + 1].t) / 2.0;
 75         for (auto x: lis) val[x] = a[x] + lamda * b[x], pos.push_back(rk[x]);
 76         sort(lis.begin(), lis.end(), cmp);
 77         sort(pos.begin(), pos.end());
 78         for (int k = 0; k < lis.size(); ++k)
 79         {
 80             int x = lis[k];
 81             //cout <<  "!! " << x << " " << rk[x] << " " << val[x] << endl; 
 82             if (rk[x] <= K) sa -= a[x], sb -= b[x];
 83             if (pos[k] <= K) sa += a[x], sb += b[x];
 84             rk[x] = pos[k];
 85             now[pos[k]] = x;
 86         }
 87         assert(sa > 0 && sb > 0);
 88         if (sa * sb < ans)
 89         {
 90                ans = sa * sb;
 91                best.clear();
 92                for (int k = 1; k <= K; ++k)
 93                    best.push_back(now[k]);
 94         }
 95     }
 96 
 97     sa = sb = 0;
 98     printf("%lld
", ans);
 99     sort(best.begin(), best.end());
100     for (int i = 0; i < best.size(); ++i)
101         printf("%d%c", best[i], i + 1 == best.size()? '
': ' '), sa += a[best[i]], sb += b[best[i]];
102     assert(sa * sb == ans);
103     return 0;
104 }
原文地址:https://www.cnblogs.com/vb4896/p/9931514.html