树中点对距离(点分治)

题目

给出一棵带边权的树,问有多少对点的距离<=Len

分析

这是一道点分治的经典题目,可以给点分治的初学者练手。
点分治,顾名思义就是把每个点分开了处理答案。
假设,目前做到了以x为根的子树。
先求出子树中每个点到根的距离(dis),对于两个点(i)(j),如果(dis_{i}+dis_{j}<=k),那么((i,j))就是一个合法的点对。
而点对的路径就会有两种:经过x点的和不经过x点的。
显然,不经过x点的一定会再x的儿子的子树中被计算过。所以,我们要减去不经过x点的。
那怎么把不经过x点的减去呢?
用以x为根的子树的(dis)值(why?如果用以x的儿子为根的子树的(dis),就会有些可以到达x的儿子的却不能到达x的点对,被多减掉),来计算以x的儿子为根的子树中的点对数量,用减去它们就可以了。

记住要找重心

#include <cmath>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <queue>
const long long maxlongint=2147483647;
using namespace std;
long long dis[12000],next[22000],last[20020],to[20200],n,m,tot,v[20200],d[5000],sum=0,size[20020],mx[20020],f,root,ans;
bool bz[20020];
long long bj(long long x,long long y,long long z)
{
	next[++tot]=last[x];
	last[x]=tot;
	to[tot]=y;
	v[tot]=z;
}
void findroot(long long x,long long fa)
{
	mx[x]=0;
    size[x]=1;
    for(long long i=last[x];i;i=next[i])
    {
        if(to[i]!=fa && (!bz[to[i]])) 
        {
        	findroot(to[i],x);
        	size[x]+=size[to[i]];
        	mx[x]=max(mx[x],size[to[i]]);
		}
    }
    mx[x]=max(mx[x],f-size[x]);
    if (mx[x]<mx[root]) root=x;
    return;
}
void q(long long l,long long r)
{
	long long i=l,j=r,mid=d[(l+r)/2],e;
	while(i<j)
	{
		while(dis[d[i]]<dis[mid]) i++;
		while(dis[d[j]]>dis[mid]) j--;
		if(i<=j)
		{
			e=d[i];
			d[i]=d[j];
			d[j]=e;
			i++;
			j--;
		}
	}
	if(i<r) q(i,r);
	if(l<j) q(l,j);
}
long long dg1(long long x,long long fa)
{
	d[++tot]=x;
	for(long long i=last[x];i;i=next[i])
	{
		long long j=to[i];
		if(fa!=j && (!bz[j]))
		{
			dis[j]=dis[x]+v[i];
			dg1(j,x);
		}
	}
}
long long getsum()
{
	q(1,tot);
	int i=1,j=tot;
	long long y=0;
 	while(i<j)
	{
		if(dis[d[i]]+dis[d[j]]-2>m)
			j--;
		else
		{
			y+=j-i;
			i++;			
		} 
	}
	return y;
}
long long dg(long long x,long long fa)
{
    bz[x]=true;
	dis[x]=1;
	tot=0;
	dg1(x,fa);
    ans+=getsum();
	for(int i=last[x];i;i=next[i])
    {
        int j=to[i];
        if(!bz[j]) 
        {
			dis[j]=v[i]+1;
        	tot=0;
			dg1(j,x);
       		ans-=getsum();
        	f=size[j];
        	root=0;
			findroot(j,x);
        	dg(root,x);
		}
    }
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(long long i=1;i<=n-1;i++)
	{
		long long x,y,z;
		scanf("%lld%lld%lld",&x,&y,&z);
		bj(x,y,z);
		bj(y,x,z);			
	}
	mx[0]=maxlongint;
	f=n;
	findroot(1,0);
	dg(root,0);
	printf("%lld
",ans);
}
原文地址:https://www.cnblogs.com/chen1352/p/9029689.html