题目链接
description
给定一颗(n)点边带权无根树,试问树上距离小于(k)点对有几何?
solution
本题为点分治板子题.点分治是一种十分重要的算法,运用了容斥原理和分治思想,其基本思路如下:
题目求的是满足条件的树上路径的条数,我们先对树上路径进行分类:
一:经过根节点的路径
二:不经过根节点的路径
对于经过根节点的路径,我们可以将其拆分为经过其不同子节点的路径进行计算.至于不经过根节点的路径,我们将其视为经过以某一节点为根节点的子树的根节点的路径,于是乎,我们成功将问题拆分为两个不相交的子问题,先处理出问题一,然后递归进子树,将问题二转化为子树的问题一即可.
为了保证算法复杂度,我们需要寻找一个合适的点作为根.这里,我们选用重心作为根,复杂度计算如下:
先介绍一个定理:对于以树的重心为根的有根树,其最大子树大小不超过(frac{n}{2}).
反证法证明如下:假设超过了,其最大子树大小(k>frac{n}{2}),那么将重心往这个子树方向移动,其最大子树一定变小,证毕
于是我们递归的次数是(log n)级别的,复杂度不超过(Omicron(n log^{2} n))
于是乎,我们有如下流程:
首先预处理出当前计算的树的重心,然后根据重心进行计算,算出结果后包含不符合的情况(计算两点同子树),在遍历子树时顺便减掉,然后再一一遍历子树即可.
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<queue>
#include<map>
#include<set>
#define R register
#define next kdjadskfj
#define debug puts("mlg")
#define mod 1000000009
#define Mod(x) ((x%mod+mod)%mod)
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
inline ll read();
inline void write(ll x);
inline void writeln(ll x);
inline void writesp(ll x);
ll n,k;
ll head[200000],next[200000],tot,to[200000],c[200000];
ll Root,vis[200000],Tsiz,size[200000],wt[200000];
ll rem[200000],cnt;
ll ans;
inline void add(ll x,ll y,ll z){to[++tot]=y;next[tot]=head[x];head[x]=tot;c[tot]=z;}
inline void getroot(ll x,ll fa){
size[x]=1;wt[x]=0;
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(ver!=fa&&!vis[ver]){
getroot(ver,x);
size[x]+=size[ver];
wt[x]=max(wt[x],size[ver]);
}
}
wt[x]=max(wt[x],Tsiz-size[x]);
if(wt[Root]>wt[x]) Root=x;
}
inline void dfs(ll x,ll D,ll fa){
rem[++cnt]=D;
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(ver!=fa&&!vis[ver]) dfs(ver,D+c[i],x);
}
}
inline ll calc(ll x,ll D){
cnt=0;dfs(x,D,0);ll l=1,r=cnt,sum=0;
sort(rem+1,rem+cnt+1);
for(;;++l){
while(r&&rem[l]+rem[r]>k) --r;
if(r<l) break;
sum+=r-l+1;
}
return sum;
}
inline void dfs2(ll x){
ans+=calc(x,0);vis[x]=true;
for(R ll i=head[x],ver;i;i=next[i]){
ver=to[i];
if(!vis[ver]){
ans-=calc(ver,c[i]);
Root=0;Tsiz=size[ver];getroot(ver,0);
dfs2(Root);
}
}
}
int main(){
n=read();
for(R ll i=1,x,y,z;i<n;i++){
x=read();y=read();z=read();
add(x,y,z);add(y,x,z);
}
k=read();
wt[0]=((ull)1<<63)-1;
Tsiz=n;getroot(1,0);
dfs2(Root);
writeln(ans-n);
}
inline ll read(){ll x=0,t=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-') t=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*t;}
inline void write(ll x){if(x<0){putchar('-');x=-x;}if(x<=9){putchar(x+'0');return;}write(x/10);putchar(x%10+'0');}
inline void writesp(ll x){write(x);putchar(' ');}
inline void writeln(ll x){write(x);putchar('
');}