AT3611Tree MST【点分治,最小生成树】

正题

题目链接:https://www.luogu.com.cn/problem/AT3611


题目大意

给出\(n\)个点的一棵树。

现在有一张完全图,两个点之间的边权为\(w_x+w_y+dis(x,y)\)\(dis\)表示树上距离)

求这张完全图的最小生成树。

\(2\leq n\leq 2\times 10^5,1\leq w_i,c_i\leq 10^9\)


解题思路

考虑可能作为最小生成树的边。

一个结论就是对于一个子图。不在最小生成森林上的边一定不在原图的最小生成树上。

这样可以考虑分治,点分治之后对于根节点\(x\),其他的节点定义\(f_x=dep_x+w_x\),那么两个点之间边权就是\(f_x+f_y\)了(\(x,y\)属于不同子树),对于同一子树的我们也加进去,因为这是不优的边所以不会影响答案。

此时图中的最小生成森林是其他所有点连接\(f\)值最小的点。

这样我们可以处理出\(n\log n\)条可能的边,在这些边上再跑一次最小生成树就好了。

时间复杂度\(O(n\log^2 n)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=2e5+10,inf=1e18;
struct node{
	ll to,next,w;
}a[N<<1];
struct edge{
	ll x,y,w;
}e[N<<5];
ll n,tot,mins,root,ans,num,ent;
ll ls[N],f[N],siz[N],w[N],fa[N];
bool v[N];
void addl(ll x,ll y,ll w){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;a[tot].w=w;
	return;
}
void groot(ll x,ll fa){
	siz[x]=1;f[x]=0;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa||v[y])continue;
		groot(y,x);siz[x]+=siz[y];
		f[x]=max(f[x],siz[y]);
	}
	f[x]=max(f[x],num-siz[x]);
	if(f[x]<f[root])root=x;
	return;
}
void calc(ll x,ll fa,ll dep){
	f[x]=w[x]+dep;
	if(f[x]<f[mins])mins=x;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa||v[y])continue;
		calc(y,x,dep+a[i].w);
	}
	return;
}
void adde(ll x,ll fa){
	e[++ent]=(edge){x,mins,f[x]+f[mins]};
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa||v[y])continue;
		adde(y,x);
	}
}
void solve(ll x){
	v[x]=1;f[x]=w[mins=x];
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(v[y])continue;
 		calc(y,x,a[i].w);
	}
	e[++ent]=(edge){x,mins,f[x]+f[mins]};
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(v[y])continue;
		adde(y,x);
	}
	ll sum=num;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(v[y])continue;
		num=(siz[y]>siz[x])?(sum-siz[x]):siz[y];
		root=0;groot(y,x);solve(root);
	}
	return;
}
bool cmp(edge x,edge y)
{return x.w<y.w;}
ll find(ll x)
{return (fa[x]==x)?x:(fa[x]=find(fa[x]));}
signed main()
{
	scanf("%lld",&n);
	for(ll i=1;i<=n;i++)
		scanf("%lld",&w[i]),fa[i]=i;
	for(ll i=1;i<n;i++){
		ll x,y,w;
		scanf("%lld%lld%lld",&x,&y,&w);
		addl(x,y,w);addl(y,x,w);
	}
	num=n;f[0]=inf;
	groot(1,1);solve(root);
	sort(e+1,e+1+ent,cmp);
	for(ll i=1;i<=ent;i++){
		ll x=e[i].x,y=e[i].y;
		x=find(x);y=find(y);
		if(x!=y)ans+=e[i].w,fa[y]=x;
	}
	printf("%lld\n",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/QuantAsk/p/14402506.html