[bzoj1977]次小生成树

先求出最小生成树,然后维护f1/f2[i][j]表示i到$2^{j}-1$祖先中最大和严格次大边,枚举生成树外的每一条边并查询这条边两点间的最大边和严格次大边,若最大边<插入边,就用插入边替换最大边计算答案,否则用插入边替换次大边计算答案

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 100005
 4 #define oo 0x3f3f3f3f
 5 struct ji{
 6     int x,y;
 7 }o,dp[N][21];
 8 struct ji2{
 9     int x,y,z;
10     bool operator < (const ji2 &k)const{
11         return z<k.z;
12     }
13 }e[N*3];
14 struct ji3{
15     int nex,to,len;
16 }edge[N<<1];
17 int E,n,m,x,ans,head[N],in[N],out[N],vis[N],f[N][21];
18 long long sum;
19 int find(int k){
20     if (k==f[k][0])return k;
21     return f[k][0]=find(f[k][0]);
22 }
23 bool pd(int x,int y){
24     return (in[x]<=in[y])&&(out[y]<=out[x]);
25 }
26 void add(int x,int y,int z){
27     edge[E].nex=head[x];
28     edge[E].to=y;
29     edge[E].len=z;
30     head[x]=E++;
31 }
32 ji merge(ji x,ji y){
33     ji z;
34     z.x=max(x.x,y.x);
35     if (z.x==x.x)z.y=x.y;
36     else z.y=x.x;
37     if (z.x==y.x)z.y=max(z.y,y.y);
38     else z.y=max(z.y,y.x);
39     return z;
40 }
41 void dfs(int k,int fa,int sh){
42     in[k]=++x;
43     f[k][0]=fa;
44     dp[k][0]=ji{sh,0};
45     for(int i=1;i<=20;i++){
46         f[k][i]=f[f[k][i-1]][i-1]; 
47         dp[k][i]=merge(dp[k][i-1],dp[f[k][i-1]][i-1]);
48     }
49     for(int i=head[k];i!=-1;i=edge[i].nex)
50         if (edge[i].to!=fa)dfs(edge[i].to,k,edge[i].len);
51     out[k]=++x;
52 }
53 ji calc(int x,int y){
54     ji z=ji{0,0};
55     for(int i=20;i>=0;i--)
56         if (!pd(f[x][i],y)){
57             z=merge(z,dp[x][i]);
58             x=f[x][i];
59         }
60     for(int i=20;i>=0;i--)
61         if (!pd(f[y][i],x)){
62             z=merge(z,dp[y][i]);
63             y=f[y][i];
64         }
65     if (!pd(x,y))z=merge(z,dp[x][0]);
66     if (!pd(y,x))z=merge(z,dp[y][0]);
67     return z;
68 }
69 int main(){
70     scanf("%d%d",&n,&m);
71     for(int i=1;i<=m;i++)scanf("%d%d%d",&e[i].x,&e[i].y,&e[i].z);
72     sort(e+1,e+m+1);
73     memset(head,-1,sizeof(head));
74     for(int i=1;i<=n;i++)f[i][0]=i;
75     for(int i=1;i<=m;i++)
76         if (find(e[i].x)!=find(e[i].y)){
77             add(e[i].x,e[i].y,e[i].z);
78             add(e[i].y,e[i].x,e[i].z);
79             f[find(e[i].x)][0]=e[i].y;
80             vis[i]=1;
81             sum+=e[i].z;
82         }
83     dfs(1,1,0);
84     ans=0x3f3f3f3f;
85     for(int i=1;i<=m;i++)
86         if (!vis[i]){
87             o=calc(e[i].x,e[i].y);
88             if (e[i].z!=o.x)ans=min(ans,e[i].z-o.x);
89             else
90                 if (o.y)ans=min(ans,e[i].z-o.y);
91         }
92     printf("%lld",ans+sum);
93 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/11329411.html