dsu on tree 学习笔记

适用范围

  1. 支持离线处理
  2. 每个询问都是针对某棵子树
  3. 没有修改操作

原理

  • 在递归遍历一整棵树并更新每棵子树产生的影响时,正常情况下我们每次递归完一棵子树都要将其撤销,否则会对其兄弟节点造成干扰
  • (dsu) (on) (tree) 仅利用了一个性质,就是在 dfs 的过程中,最后遍历的那个不需要撤销,因为它的兄弟节点已经遍历完了,这时候我们就可以直接让父节点继承这个点的信息
  • 因为最后的子树不需要撤销,所以我们想要让其越大越好,也就是撤销的越少越好,这样会保证我们的时间效率尽可能地高
  • 核心也就有了:利用树链剖分的性质,最后递归重儿子且不撤销优先遍历轻儿子并撤销父节点继承重儿子的信息

时间复杂度

  • 大部分人看到这里都会觉得:这有什么区别吗?不就少搞一个重儿子吗。然而事实上是,它将 (O(n^2)) 的效率变成了 (O(nlogn)) 的效率
  • 依据树链剖分的性质,一个子树内轻边的数量保证不会超过总边数的一半,考虑最坏的情况,每次 dfs 到一个新节点,那么轻边数就会减半,最后遍历完后就只有了 (logn) 条轻边

例题

T1 树上数颜色
T2 CF600E Lomsat gelral
T3 射手座之日

这两题都是模板题,这里放上 (T1) 的代码作为板子用以参考

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define R register
#define N 100010
using namespace std;
inline int read(){
	int x = 0,f = 1;
	char ch = getchar();
	while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
int n,m,head[N],siz[N],f[N],son[N],cnt[N],c[N];
long long ans[N];
struct edge{
	int to,next;
}e[N<<1];
int len;
void addedge(int u,int v){
	e[++len].to = v;
	e[len].next = head[u];
	head[u] = len;
}
void dfs(int u,int fa){//求重儿子
	siz[u] = 1;
	f[u] = fa;
	for(R int i = head[u];i;i = e[i].next){
		int v = e[i].to;
		if(v==fa)continue;
		dfs(v,u);
		siz[u] += siz[v];
		if(siz[v]>siz[son[u]])son[u] = v;
	}
}
void update(int u,int val,int p){//统计答案,p表示需要跳过的点
	cnt[c[u]] += val;
	if(val==1&&cnt[c[u]]==1)sum++;
	if(val==-1&&cnt[c[u]]==0)sum--;
	for(R int i = head[u];i;i = e[i].next){
		int v = e[i].to;
		if(v==f[u]||v==p)continue;
		update(v,val,p);
	}
}
void dsu(int u,int opt){//opt表示撤销or不撤销对答案的影响
	for(R int i = head[u];i;i = e[i].next){
		int v = e[i].to;
		if(v==f[u]||v==son[u])continue;//先递归轻儿子
		dsu(v,0);
	}
	if(son[u])dsu(son[u],1);//最后递归重儿子
	update(u,1,son[u]);//统计轻儿子的答案,记得跳过重儿子
	ans[u] = sum;
	if(opt==0){//撤销影响
		update(u,-1,0);
		sum = 0;
	}
}

int main(){
	n = read();
	for(R int i = 1;i < n;i++){
		int u = read(),v = read();
		addedge(u,v),addedge(v,u);
	}
	for(R int i = 1;i <= n;i++)c[i] = read();
	dfs(1,0);
	dsu(1,0);
	m = read();
	for(R int i = 1;i <= m;i++){
		int x = read();
		printf("%lld
",ans[x]);
	}
	return 0;
}

参考资料

Hypoc_
自为风月马前卒
The End

原文地址:https://www.cnblogs.com/hhhhalo/p/13776209.html