[loj2478]林克卡特树

原题等价于选择恰好$k+1$条不相交(无公共点)的路径使得边权和最大
证明:对于原题中的最优解,一定包含了k条0边权的边(否则可以将未使用的边删掉,然后将这条路径的末尾与不在同一个连通块内的点连边),那么选择这k条0边权的边所划分的$k+1$条路径即可;对于这$k+1$条路径,将每一条路径首尾连0边权的边,由于这些0边权的边和选择的边无法构成环,因此一定可以删除k条为选择的非0边使其变成一棵树,即原题中的操作
然后令$f(k)$表示选择了恰好k条路径的答案,那么有对于$forall 1le i<n$,都有$2f(i)ge f(i-1)+f(i+1)$,即$f(i)-f(i-1)ge f(i+1)-f(i)$
证明:建立一张费用流的图:S->i(1,0);i->i'(1,0);i'->T(1,0);i'->j(1,v(i,j))。容易发现$f(x)= 流量为x的最大费用$,由于费用流存在凸性,所以f也存在凸性
根据凸性二分即可,即二分$f(i)-f(i-1)ge k$,考虑判定:将每条路径权值减去k并选择任意条路径使得权值和最大,那么最后即求出了$f(i)-ki$(特殊情况:$f(k+1)-f(k)=……=f(k+i)-f(k+i-1)$,那么只可以找到$f(k+i)$和$f(k)$,根据等式求出$f(k+1)$即可)
具体的树形dp:用$f[i][j=0/1/2]$表示以i为根的子树选择的端点包含i的边数j,转移分类讨论即可(注意:根据二分的过程,我们要选择尽量多的路径,因此还要记录对应的路径数量,可以用结构体来转移) 
 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 300005
 4 #define oo 1e12
 5 #define ll long long
 6 #define pli pair<ll,int>
 7 #define fi first
 8 #define se second
 9 #define mx(k) max(f[k][0],max(f[k][1],f[k][2]))
10 int E,n,m,k,x,y,z,head[N];
11 pli o,f[N][3];
12 struct ji{
13     int nex,to,len;
14 }edge[N<<1];
15 pli add(pli x,pli y){
16     return make_pair(x.fi+y.fi,x.se+y.se);
17 }
18 void add(int x,int y,int z){
19     edge[E].nex=head[x];
20     edge[E].to=y;
21     edge[E].len=z;
22     head[x]=E++;
23 }
24 void dfs(int k,int fa,ll v){
25     f[k][0]=make_pair(0,0);
26     f[k][1]=f[k][2]=make_pair(-v,1);
27     for(int i=head[k];i!=-1;i=edge[i].nex)
28         if (edge[i].to!=fa){
29             int u=edge[i].to;
30             dfs(u,k,v);
31             memcpy(f[0],f[k],sizeof(f[0]));
32             for(int j=0;j<3;j++)f[k][j]=add(f[k][j],mx(u));
33             f[k][1]=max(f[k][1],add(add(f[0][0],f[u][1]),make_pair(edge[i].len,0)));
34             f[k][2]=max(f[k][2],add(add(f[0][1],f[u][1]),make_pair(edge[i].len+v,-1)));
35         } 
36 }
37 pli pd(ll k){
38     dfs(1,0,k);
39     return mx(1);
40 }
41 int main(){
42     scanf("%d%d",&n,&m);
43     m++;
44     memset(head,-1,sizeof(head));
45     for(int i=1;i<n;i++){
46         scanf("%d%d%d",&x,&y,&z);
47         add(x,y,z);
48         add(y,x,z);
49     }
50     ll l=-oo,r=oo;
51     while (l<r){
52         ll mid=(l+r+1>>1);
53         if (pd(mid).se>=m)l=mid;
54         else r=mid-1;
55     }
56     o=pd(l-1);
57     printf("%lld",o.fi+o.se*(l-1)+l*(m-o.se));
58 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/12972254.html