点分治与点分树 学习笔记

一、引入

我们看这样的一道题:给定一棵有\(m\)的点的树,询问树上是否存在距离为\(k\)的点对,有\(m\)组询问,\(nm\le 10^6\).

容易想到的做法是暴力,即对于每一个点对都暴力判断,显然,它的时间复杂度不够优秀。

点分治,就是解决像这道题一样需要大规模处理树上路径的问题的一种好方法。

二、算法流程

考虑这样的情况,我们面对的并不是一棵树,而是一个序列的是有关区间的区间和之类的问题。

对于这样的问题,我们知道可以通过选中点,分治处理左右部分,然后单独考虑跨过中点的区间的贡献,而处理后者往往比处理原问题容易地多。

我们希望将这样的分治策略扩展的树上,但我们该如何定义树上的中点呢?

这里我们引入一个概念——树的重心

它的定义是:在一棵树上,找到一个点,使以这个点作为根时,所有子树中大小最大的一个的大小最小(一棵子树的大小就是这棵子树的节点个数),那么这个点就是重心。

根据定义,一旦我们选择重心,然后递归处理每一棵子树,每一棵子树的大小都不会超过原树的\(\frac 12\),这保证了点分治的复杂度。

于是,我们就能写出算法流程了:

  • 1.找到重心
  • 2.处理经过重心的路径的贡献
  • 3.递归处理重心的所有子树

三、实现

寻找重心部分:

int rt,mxsiz;
inline void getsiz(int u,int f){
	siz[u]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		getsiz(v,u);siz[u]+=siz[v];
	}
	return ;
}
int dis[N],d[N];
inline void findrt(int u,int f,int sum){//sum是目前正在处理的这棵子树的大小
	int mx=sum-siz[u];
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;//vis是对已经选择过的重心打的标记,这样就能防止遍历到其他子树上	
		findrt(v,u,sum);
		mx=max(mx,siz[v]);
	}
	if(mx<mxsiz) mxsiz=mx,rt=u;
}

接下来是分治的实现方式:

首先先找到重心,然后标记重心以防遍历子树时通过重心跑到另一个子树上

然后处理经过重心的路径,也就是这里的\(calc(rt)\)\(calc\)是每道题目都不相同的函数。

一般实现方式是:依次遍历子树,增加这棵子树中的点和已经遍历过的点的贡献

inline void solve(int u){
	rt=0;mxsiz=0x3f3f3f3f;sz=0;
	getsiz(u,0);
	findrt(u,0,siz[u]);
	vis[rt]=1;
	ans+=calc(rt);
	for(int i=first[rt];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		solve(v);
	}
}

接下来我们通过例题来讲解\(calc\)函数的\(2\)种实现方式:

四、例题

洛谷模板题

这就是我们在引入中所提到的题目。

我们可以开一个数组记录已经遍历过的节点到根的距离,增加一棵子树时,就依次考虑这棵树中所有节点到根的距离\(d\),并判断一下是否已经有距离为\(k-d\)的点即可:

#include<bits/stdc++.h>
using namespace std;
const int M=1e7+10;
const int N=2e5+10;
int n,m,k,rt,sz,mxsiz,mxk,ans=0,vis[N],siz[N],cnt,first[N],w[M],K[N],anss[N];
struct node{
	int v,w,nxt;
}e[N<<1];
inline void add(int u,int v,int w){e[++cnt].v=v;e[cnt].w=w;e[cnt].nxt=first[u];first[u]=cnt;}
inline void getsiz(int u,int f){
	siz[u]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		getsiz(v,u);siz[u]+=siz[v];
	}
	return ;
}
int dis[N],d[N];
inline void findrt(int u,int f,int sum){
	int mx=sum-siz[u];
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;	
		findrt(v,u,sum);
		mx=max(mx,siz[v]);
	}
	if(mx<mxsiz) mxsiz=mx,rt=u;
}
inline void getdis(int u,int f){
	dis[++sz]=d[u];
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		d[v]=d[u]+e[i].w;
		getdis(v,u);
	}
}
inline void calc(int u){
	vector<int> used; 
	w[0]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		sz=0;d[v]=e[i].w;
		getdis(v,u);
		for(int t=1;t<=m;++t){
			if(anss[t]) continue;
			for(int j=1;j<=sz;++j){
				if(dis[j]>K[t]) continue;
				if(w[K[t]-dis[j]]){
					anss[t]=1;
					break;
				}
			}
		}
		for(int j=sz;j>0;j--)
			if(dis[j]<=mxk) used.push_back(dis[j]),w[dis[j]]=1;
	}
	for(int i=0;i<used.size();++i) w[used[i]]=0;
}
inline void solve(int u){
	rt=0;mxsiz=0x3f3f3f3f;sz=0;
	getsiz(u,0);
	findrt(u,0,siz[u]);
	vis[rt]=1;
	calc(rt);
	for(int i=first[rt];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		solve(v);
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1,u,v,w;i<n;++i){
		scanf("%d%d%d",&u,&v,&w);
		add(u,v,w);add(v,u,w);
	}
	for(int i=1;i<=m;++i) scanf("%d",&K[i]),mxk=max(mxk,K[i]);
	solve(1);
	for(int i=1;i<=m;++i)
		if(anss[i]) puts("AYE");
		else puts("NAY"); 
	return 0;
}

聪聪可可

依然没有什么区别,不过这次我们只关心路径长度\(\%3\)的余数:

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
int n,m,k,rt,sz,mmx,mxk,ans=0,vis[N],siz[N],cnt,first[N],w[3];
struct node{
	int v,w,nxt;
}e[N<<1];
inline void add(int u,int v,int w){e[++cnt].v=v;e[cnt].w=w;e[cnt].nxt=first[u];first[u]=cnt;}
inline void getsiz(int u,int f){
	siz[u]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		getsiz(v,u);siz[u]+=siz[v];
	}
	return ;
}
int dis[N],d[N];
inline void findrt(int u,int f,int sum){
	int mx=sum-siz[u];
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;	
		findrt(v,u,sum);
		mx=max(mx,siz[v]);
	}
	if(mx<mmx) mmx=mx,rt=u;
}
inline void getdis(int u,int f){
	dis[++sz]=d[u];
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		d[v]=d[u]+e[i].w;
		getdis(v,u);
	}
}
inline void calc(int u){
	vector<int> used; 
	w[0]=1;w[1]=w[2]=0;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		sz=0;d[v]=e[i].w;
		getdis(v,u);
		for(int j=1;j<=sz;++j) ans+=w[(3-(dis[j]%3))%3];
		for(int j=1;j<=sz;++j) w[dis[j]%3]++;
	}
}
inline void solve(int u){
	rt=0;mmx=0x3f3f3f3f;sz=0;
	getsiz(u,0);
	findrt(u,0,siz[u]);
	vis[rt]=1;
	calc(rt);
	for(int i=first[rt];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		solve(v);
	}
}
int main(){
	scanf("%d",&n);
	for(int i=1,u,v,w;i<n;++i){
		scanf("%d%d%d",&u,&v,&w);
		add(u,v,w);add(v,u,w);
	}
	solve(1);
	int fz=ans*2+n,fm=n*n,g=__gcd(fz,fm);
	printf("%d/%d\n",fz/g,fm/g);
	return 0;
}

【UR #2】树上GCD

题意:给出一棵以 \(1\) 为根的树,对于 \(i\in [1,n-1]\) 求有多少对 \((u,v)(u<v)\) ,设 \(a=LCA(u,v)\) ,满足 \(\gcd(dis(u,a),dis(v,a))=i\)\(n\le 2\times 10^5\)

\(f(u,v)=\gcd(dis(u,a),dis(v,a))\),我们可以先对每个 \(i\) 求出满足 \(i\mid f(u,v)\)\((u,v)\) 数量,然后可以反演得到答案。

对于这个新问题,容易想到点分治,对于重心 \(i\) 计算树上路径经过 \(i\) 的所有点对 \((u,v)\) 对答案的贡献,分两种情况考虑:

  • \(u,v\) 均在 \(i\) 的子树中(原树上的子树),那么 \(u,v\)\(LCA\) 就是 \(i\),可以直接 \(dfs\) 一遍计算 \(ct_x\) 表示到 \(i\) 距离为 \(x\) 的点数量,然后对于 \([1,n-1]\) 暴力枚举其倍数进行统计,复杂度是 \(\mathcal O(n\log n)\)
  • \(u\)\(i\) 子树中,\(v\) 不在:枚举 \(x=LCA(u,v)\),那么 \(v\) 的数量可以使用和上一种情况同样的方法遍历 \(x\) 的子树得到。但同时我们还要对每个 \(d\)\(i\) 子树内有多少个点到 \(x\) 的距离为 \(d\) 的倍数:如果 \(d>\sqrt{n}\),我们可以直接这样的点到 \(i\) 的距离,利用 \(ct\) 数组进行统计,这样的数不超过 \(\mathcal O(\sqrt{n})\) 个;否则,我们可以预处理 \(f_{d,r}\) 表示到 \(i\) 的距离模 \(d\)\(r\) 的点数量,由于 \(d\) 只有 \(\sqrt{n}\) 个,所以暴力计算 \(f\) 数组依然是 \(\mathcal O(n\sqrt{n})\) 的。

因此单次遍历的复杂度是 \(\mathcal O(n\sqrt{n})\),总复杂度也就是 \(\mathcal O(n\sqrt{n})\)

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
#define pb push_back
typedef long long ll; 
int n,sq,fa[N];
ll ans[N],sum[N],anss[N],de[N];
vector<int> to[N];
#define LOCAL
namespace iobuff{
	const int LEN=1000000;
	char in[LEN+5], out[LEN+5];
	char *pin=in, *pout=out, *ed=in, *eout=out+LEN;
	inline char gc(void)
	{
		#ifdef LOCAL
		return getchar();
		#endif
		return pin==ed&&(ed=(pin=in)+fread(in, 1, LEN, stdin), ed==in)?EOF:*pin++;
	}
	inline void pc(char c)
	{
		pout==eout&&(fwrite(out, 1, LEN, stdout), pout=out);
		(*pout++)=c;
	}
	inline void flush()
	{ fwrite(out, 1, pout-out, stdout), pout=out; }
	template<typename T> inline void read(T &x)
	{
		static int f;
		static char c;
		c=gc(), f=1, x=0;
		while(c<'0'||c>'9') f=(c=='-'?-1:1), c=gc();
		while(c>='0'&&c<='9') x=10*x+c-'0', c=gc();
		x*=f;
	}
	template<typename T> inline void putint(T x, char div)
	{
		static char s[15];
		static int top;
		top=0;
		x<0?pc('-'), x=-x:0;
		while(x) s[top++]=x%10, x/=10;
		!top?pc('0'), 0:0;
		while(top--) pc(s[top]+'0');
		pc(div);
	}
}
using namespace iobuff;
 
 
int rt,vrt,siz[N],vis[N];
ll ss[N];
inline int getsiz(int u,int f){
	siz[u]=1;
	for(int v:to[u])
		if(v!=f&&!vis[v]) siz[u]+=getsiz(v,u);
	return siz[u];
}
inline void getrt(int u,int f,int sum){
	int mx=sum-siz[u];
	for(int v:to[u])
		if(v!=f&&!vis[v]) getrt(v,u,sum),mx=max(mx,siz[v]);
	if(mx<vrt) rt=u,vrt=mx; 
}
int dep[N],cnt[N],mxd,v[1010][1010],ct[N]; 
inline void getdep(int u,int f){
	cnt[dep[u]]++;
	mxd=max(mxd,dep[u]);
	for(int v:to[u])
		if(v!=f&&!vis[v]) dep[v]=dep[u]+1,getdep(v,u); 
}
inline void init(){
	fill(cnt,cnt+mxd+1,0);
	mxd=0;
}

inline void calc(int u,int sum){
	int mx=0;
	for(int i=0;i<=sum;++i) cnt[i]=ct[i]=0;
	for(int v:to[u]){
		if(vis[v]||de[v]<de[u]) continue;
		init();
		dep[v]=1;getdep(v,u);
		for(int i=1;i<=mxd;++i){
			ss[i]=cnt[i];
			for(int j=i<<1;j<=mxd;j+=i) ss[i]+=cnt[j];
			ans[i]+=ss[i]*ct[i];ct[i]+=ss[i]; 
		}
		mx=max(mx,mxd);
	}
	for(int i=mx;i>=1;--i){
		for(int j=i<<1;j<=mx;j+=i) ct[i]-=ct[j];
	}
	mxd=mx;
	int S=(int)(0.4*sqrt(mxd)+0.5);
	ct[0]++;
	for(int i=1;i<=S;++i){
		for(int j=0;j<i;++j) v[i][j]=0;
		for(int j=0;j<=mxd;++j) v[i][j%i]+=ct[j];
	}
	int now=u,last=u,dis=0;
	while(fa[now]&&!vis[fa[now]]){
		last=now;now=fa[now];++dis;
		init();
		for(int v:to[now])
			if(!vis[v]&&v!=last&&v!=fa[now]){
				dep[v]=1;
				getdep(v,now);
			}
		for(int i=1;i<=mxd;++i){
			ll s=0;
			for(int j=i;j<=mxd;j+=i) s+=cnt[j];
			int nxt=(i-dis%i)%i;
			if(i<=S) ans[i]+=s*v[i][nxt];
			else
				for(int j=nxt;j<=mx;j+=i) ans[i]+=s*ct[j];
		}
	}
}

inline void solve(int u){
	rt=0;vrt=0x3f3f3f3f;
	getrt(u,0,getsiz(u,0));
	vis[rt]=1;
	calc(rt,siz[u]);
	for(int v:to[rt])
		if(!vis[v]) solve(v);
}
inline void init_dep(int u,int f){
	for(int v:to[u]) if(v!=f) de[v]=de[u]+1,init_dep(v,u);
}
int main(){
//	freopen("in.in","r",stdin);
//	freopen("out.out","w",stdout);
	read(n);
	sq=ceil(sqrt(n)); 
	for(int i=2,f;i<=n;++i)
		read(f),to[f].pb(i),to[i].pb(f),fa[i]=f;
	init_dep(1,0);
	solve(1);
	for(int i=n-1;i>=1;--i)
		for(int j=i<<1;j<n;j+=i) ans[i]-=ans[j];
	for(int i=1;i<=n;++i) anss[de[i]]++;
	for(int i=n-1;i>=1;--i) anss[i]+=anss[i+1],ans[i]+=anss[i];
	for(int i=1;i<n;++i)
		putint(ans[i],'\n');
	flush();
	return 0;
} 

至此,大家大概已经理解点分治的原理和实现了,接下来我们再介绍一下点分树,又叫动态点分治:

五、点分树

我们还是从模板题讲起:

给一棵\(n\)个点的树,多次询问距离一个点\(u\)距离小于\(k\)的点的点权和,有单点修改,强制在线,数据范围\(10^5\)

考虑暴力的方法:每次查询都暴力点分治:对于询问\(u,k\),每次都找到重心,假设重心与\(u\)的距离为\(d\),将重心其他子树(即不包括\(u\)的子树)中到根距离\(\le d-u\)的点点权加上。

注意到每一次我们都查找的是同样的重心,因此我们可以考虑先找出点分过程中所有的重心,将上一层的重心连接下一层的重心形成一个树,这就是点分树了。

实现大概如下:

inline void getsiz(int u,int f){
	siz[u]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		getsiz(v,u);siz[u]+=siz[v];
	}
	return ;
}
inline void findrt(int u,int f,int sum){
	int ret=sum-siz[u];
	for(int i=first[u];i;i=e[i].nxt){ 
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		findrt(v,u,sum);ret=max(ret,siz[v]);
	}
	if(ret<mx) mx=ret,rt=u;
}
 
inline void divide(int u,int f){
	mx=0x3f3f3f3f;getsiz(u,0);
	findrt(u,0,siz[u]);
	dsiz[rt]=siz[u];
	vis[rt]=1;fa[rt]=f;
	dep[rt]=dep[f]+1;
	int p=rt; 
	for(int i=first[rt];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		divide(v,p);
	}
}

有了点分树,我们就可以每次查询都沿着树走,依次统计路上的答案,显然点分树树高是\(log(n)\)的,如果每一个点的贡献都能较快查询,那我们就能通过此题了:

基本思路大概是对于点分树上每一个点\(u\),开两个数据结构\(S1,S2\)\(S1\)记录点分树上\(u\)的子树到\(u\)的路径的信息,\(S2\)记录点分树上\(u\)的子树到\(u\)的父亲的路径的信息。

那么修改点\(u\)时,暴力爬点分树,修改它对应的\(S1、S2\),复杂度就是数据结构查询复杂度\(*log(n)\)

查询点\(u\)是,暴力向上爬点分树,增加\(S1\)的贡献,减去\(S2\)的贡献即可。

对于模板题,我们每个点开两个动态开点线段树,每个叶子节点表示\(u\)点分树上的子树中到\(u\)\(u\)的父亲距离为\(i\)的点的点权和即可。

那么查询\(x,k\)时,如果我们跳到了\(u\),上一个访问的节点是\(last\),我们就可以在\(u\)\(S1\)中查下标在\(k-dis(u,x)\)以内的前缀和的答案加上,但这样会将\(last\)所在子树中贡献误算进去,于是我们减去,\(last\)\(S2\)\(k-dis(u,x)\)的前缀和即可。

这样我们就能\(\mathcal O(mlog^2n)\)完成此题了。

#include<bits/stdc++.h>
using namespace std;
const int N=4e5+10;
int n,m,fa[N],val[N],first[N],cnt,dep[N],mxd,tot,st[N][20],id[N],Dep[N],dsiz[N];
struct node{
	int v,nxt;
}e[N<<1];
inline void add(int u,int v){e[++cnt].v=v;e[cnt].nxt=first[u];first[u]=cnt;} 
int rt=0,mx,siz[N],vis[N],allrt;
inline void getsiz(int u,int f){
	siz[u]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		getsiz(v,u);siz[u]+=siz[v];
	}
	return ;
}
inline void findrt(int u,int f,int sum){
	int ret=sum-siz[u];
	for(int i=first[u];i;i=e[i].nxt){ 
		int v=e[i].v;
		if(v==f||vis[v]) continue;
		findrt(v,u,sum);ret=max(ret,siz[v]);
	}
	if(ret<mx) mx=ret,rt=u;
}
 
inline void divide(int u,int f){
	mx=0x3f3f3f3f;getsiz(u,0);
	findrt(u,0,siz[u]);
	dsiz[rt]=siz[u];
	vis[rt]=1;fa[rt]=f;
	dep[rt]=dep[f]+1;
	int p=rt; 
	for(int i=first[rt];i;i=e[i].nxt){
		int v=e[i].v;
		if(vis[v]) continue;
		divide(v,p);
	}
}
 
int rt1[N],rt2[N];
namespace SGT{
	int tot,ls[N<<5],rs[N<<5],sum[N<<5]; 
	#define mid (l+r>>1)
	inline void pushup(int p){sum[p]=sum[ls[p]]+sum[rs[p]];} 
	inline void update(int &p,int x,int v,int l,int r){
		if(!p) p=++tot;
		if(l==r){sum[p]+=v;return ;}
		if(x<=mid) update(ls[p],x,v,l,mid);
		else update(rs[p],x,v,mid+1,r);
		pushup(p);
	}
	inline int query(int p,int ql,int qr,int l,int r){
		if(ql<=l&&r<=qr) return sum[p];
		int ret=0;
		if(ql<=mid&&ls[p]) ret+=query(ls[p],ql,qr,l,mid);
		if(qr>mid&&rs[p]) ret+=query(rs[p],ql,qr,mid+1,r);
		return ret;
	}
	#undef mid
}
 
inline int getmn(int x,int y){
	return Dep[x]<Dep[y]?x:y;
}
inline void ST(){
	for(int i=1;i<20;++i)
		for(int j=1;j+(1<<i)-1<=tot;++j)
			st[j][i]=getmn(st[j][i-1],st[j+(1<<i-1)][i-1]);
} 
 
inline int LCA(int x,int y){
	if(id[x]>id[y]) swap(x,y);
	int t=log2(id[y]-id[x]+1);
	return getmn(st[id[x]][t],st[id[y]-(1<<t)+1][t]);	
}
inline int getdis(int u,int v){
	return Dep[u]+Dep[v]-2*Dep[LCA(u,v)];
}
inline int query(int u,int k){
	int last=0,ans=0,now=u;
	while(u!=0){
		int d=getdis(u,now);
		if(d<=k){
			ans+=SGT::query(rt1[u],0,k-d,0,dsiz[u]);
			if(last) ans-=SGT::query(rt2[last],0,k-d,0,dsiz[last]); 
		}
		last=u;u=fa[u];
	}
	return ans;
}
inline void update(int u,int v){
	int now=u;
	int sum=0;
	while(u!=0){
		SGT::update(rt1[u],getdis(u,now),v,0,dsiz[u]);
		if(fa[u]!=0) SGT::update(rt2[u],getdis(fa[u],now),v,0,dsiz[u]);
		u=fa[u];
	}		
} 
 
 
inline void dfs(int u,int f){
	st[++tot][0]=u;id[u]=tot;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f) continue;
		Dep[v]=Dep[u]+1;
		dfs(v,u);
		st[++tot][0]=u;
	}
}
inline void pre(){
	dfs(1,0);
	ST();
	divide(1,0);
	for(int i=1;i<=n;++i) mxd=max(mxd,Dep[i]);
	for(int i=1;i<=n;++i) update(i,val[i]);
}
int main(){
//	freopen("6329.in","r",stdin);
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;++i) scanf("%d",&val[i]);
	for(int i=1,u,v;i<n;++i) scanf("%d%d",&u,&v),add(u,v),add(v,u);
	pre();
	int lastans=0;
	for(int i=1,op,x,y;i<=m;++i){
		scanf("%d%d%d",&op,&x,&y);
		x^=lastans;y^=lastans;
		if(!op) printf("%d\n",lastans=query(x,y));
		else{
			update(x,y-val[x]);
			val[x]=y;
		}
	}
	return 0;
}

原文地址:https://www.cnblogs.com/tqxboomzero/p/14238684.html