P4178 Tree 题解

Link

P4178 Tree

Solve

因为这棵树是无根树,所以我们设节点(p)为根,对于(p)而言,树上的路径可以分成两类,

1.经过根节点(p)

2.包含于(p)的某一个子树中(不经过根节点)

第二类路径我们可以用分治作为子问题,只需要统计第一类路径就好了。

我们把第一类路径分成两段(x)$p$和$p$(y)的,我们遍历整棵子树,统计(dis[x])表示(x)节点到(p)的距离。

对于一棵子树,我们统计子树上的每个节点,累加小于 (K-dis[x])的节点,即在另外一棵子树上和的距离小于(K)的个数,自然而然就想到了用树状数组统计。

因为这道题(K)非常小,所以不用离散化,否则要离散化或者用平衡树代替树状数组。

注意对于每棵子树,要先统计在标记,否则会把这棵子树上的点也上。

Code

#include<bits/stdc++.h>
using namespace std;
const int maxn=40005,maxe=80005,maxk=20005,INF=1<<30;

inline int read(){
	int ret=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-f;ch=getchar();}
	while(ch<='9'&&ch>='0')ret=ret*10+ch-'0',ch=getchar();
	return ret*f;
}
int ans,N,K,cnt,size[maxn],son[maxe],nxt[maxe],lnk[maxn],vis[maxn],w[maxe],c[maxk];

inline void add_e(int x,int y,int z){
	son[++cnt]=y;w[cnt]=z;nxt[cnt]=lnk[x];lnk[x]=cnt;
}

inline void add_x(int x,int date){
	for(int i=x;i<=K;i+=i&-i)
	c[i]+=date;
	return ;
}

inline int get(int x){
	int S=0;
	for(int i=x;i;i-=i&-i)S+=c[i];
	return S;
}

int max(int a,int b){return a>b?a:b;}
void get_size(int x,int fa){
	size[x]=1;
	for(int j=lnk[x];j;j=nxt[j]){
		if(son[j]==fa||vis[son[j]])continue;
		get_size(son[j],x);
		size[x]+=size[son[j]];
	}
}

int max_x=INF,root;
void get_root(int x,int fa,int tot){
	int now_max=0;
	for(int j=lnk[x];j;j=nxt[j]){
		if(son[j]==fa||vis[son[j]])continue;
		now_max=max(now_max,size[son[j]]);
		get_root(son[j],x,tot);
	}
	now_max=max(now_max,tot-size[x]);
	if(now_max<max_x){max_x=now_max,root=x;}
	return ;
}

int s[maxn],sav[maxn];
void DFS(int x,int fa,int dis){
	if(dis>K)return ;
	s[++s[0]]=dis;sav[++sav[0]]=dis;
	for(int j=lnk[x];j;j=nxt[j]){
		if(vis[son[j]]||son[j]==fa)continue;
		DFS(son[j],x,dis+w[j]);
	}
	return ;
}

void solve(int x){
	sav[0]=0;max_x=N;
	get_size(x,0);
	get_root(x,0,size[x]);
	vis[root]=1;
	for(int j=lnk[root];j;j=nxt[j]){
		if(vis[son[j]])continue;
		s[0]=0;DFS(son[j],root,w[j]);
		for(int i=1;i<=s[0];i++){if(s[i]>K)continue;ans+=get(K-s[i]);}
		for(int i=1;i<=s[0];i++){if(s[i]>K)continue;add_x(s[i],1);ans++;}
	}
	for(int i=1;i<=sav[0];i++){
		if(sav[i]>K)continue;
		add_x(sav[i],-1);
	}
	for(int j=lnk[root];j;j=nxt[j]){
		if(!vis[son[j]])solve(son[j]);
	}
	return ;
}

int main(){
	freopen("P4178.in","r",stdin);
	freopen("P4178.out","w",stdout);
	N=read();
	for(int i=1;i<N;i++){
		int x=read(),y=read(),z=read();
		add_e(x,y,z);add_e(y,x,z);
	}
	K=read();
	solve(1);
	printf("%d
",ans);
	return 0;
} 
原文地址:https://www.cnblogs.com/martian148/p/13893353.html