「清华集训2016」连通子树 (点分治+dfs序dp+虚树)

「清华集训2016」连通子树 (点分治+dfs序dp+虚树)

丧心病狂系列

首先对于会影响答案的点构建虚树,然后点分治+dfs序dp常见套路。。。

点分治+dfs序dp好题:HDU 5909

由于构建虚树之后\(x,y\)之间的点随便选联通块的方案还需要预处理,最好是倍增吧。。

底层是子树随便选,点之间是倍增处理,都需要换根\(\text{dp}\)预处理

算法嵌套大赛

#include<bits/stdc++.h>

using namespace std;

#define reg register
typedef long long ll;
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)

#define pb push_back
template <class T> inline void cmin(T &a,T b){ if(a>b) a=b; }
template <class T> inline void cmax(T &a,T b){ if(a<b) a=b; }

char IO;
template <class T=int> T rd(){
	T s=0;
	int f=0;
	while(!isdigit(IO=getchar())) if(IO=='-') f=1;
	do s=(s<<1)+(s<<3)+(IO^'0');
	while(isdigit(IO=getchar()));
	return f?-s:s;
}

const int N=1e5+10,P=1e9+7;

int n,m;
vector <int> G[N],V[N],E[N];
int col[N],a,ca,b,cb,c,cc,fa[N][20],dep[N];
int L[N],dfn;

struct Node{
	ll s,ls,rs,sum;
	Node operator + (const Node __) const {
		Node res;
		res.s=s*__.s%P;
		res.ls=(ls+s*__.ls)%P;
		res.rs=(__.s*rs+__.rs)%P;
		res.sum=(sum+__.sum+rs*__.ls)%P;
		return res;
	}
}s[N][20],tmp[N];//倍增,整段选了,区间左边连续选,右边连续选,整段连续随便选
int tmp2[N];

ll dp[N],g[N],up[N],all[N],Idp[N],Iup[N];
// dp子树随便选,up外面随便选
ll qpow(ll x,ll k) {
	ll res=1;
	for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
	return res;
}


Node Que(int x,int f) {
	Node res=(Node){1,0,0,0};
	drep(i,18,0) if(dep[fa[x][i]]>dep[f]) res=s[x][i]+res,x=fa[x][i];
	return res;
} // x,f路径上的点随便选,不包括x,f

void pre_dfs(int u,int f) {
	L[u]=++dfn,dep[u]=dep[fa[u][0]=f]+1;
	rep(i,1,18) fa[u][i]=fa[fa[u][i-1]][i-1];
	dp[u]=1;
	for(int v:G[u]) if(v!=f) {
		pre_dfs(v,u);
		dp[u]=dp[u]*(dp[v]+1)%P;
		g[u]=(g[u]+g[v]+dp[v])%P;
	}
}

void redfs(int u,int f) {
	if(f) {
		ll t=dp[f]*qpow(dp[u]+1,P-2)%P;
		s[u][0]=(Node){t,t,t,((t+g[f]-g[u]-dp[u])%P+P)%P};
		rep(i,1,18) if(fa[u][i]) s[u][i]=s[fa[u][i-1]][i-1]+s[u][i-1];
	}
	up[u]=f?(all[f]*qpow(dp[u]+1,P-2)%P):1;
	all[u]=up[u]*dp[u]%P;
	for(int v:G[u]) if(v!=f) redfs(v,u);
	Idp[u]=qpow(dp[u]+1,P-2),Iup[u]=qpow(up[u],P-2);
} // 换根dp预处理

int line[N],cnt,stk[N],top;
int LCA(int x,int y) {
	if(dep[x]<dep[y]) swap(x,y);
	for(int del=dep[x]-dep[y],i=0;(1<<i)<=del;++i) if(del&(1<<i)) x=fa[x][i];
	if(x==y) return x;
	drep(i,18,0) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}

void Insert(int u){
	int lca=LCA(stk[top],u);
	if(lca==stk[top]) { E[u].clear(); stk[++top]=u; return; }
	while(top>1 && L[stk[top-1]]>L[lca]) {
		E[stk[top-1]].pb(stk[top]);
		top--;
	}
	if(stk[top-1]!=lca) E[lca].clear();
	E[lca].pb(stk[top--]);
	if(stk[top]!=lca) stk[++top]=lca;
	stk[++top]=u,E[u].clear();
} 

void Construct(){
	sort(line+1,line+cnt+1,[&](const int x,const int y){ return L[x]<L[y]; });
	stk[top=1]=1,E[1].clear();
	rep(i,1,cnt) if(line[i]!=1) Insert(line[i]);
	while(top>1) E[stk[top-1]].pb(stk[top]),top--;
}// 构建虚树

ll ans=0,f[N][6][6][6];
int vis[N],nfa[N]; //虚树上的father
int Up(int u,int f) {
	drep(i,18,0) if(dep[fa[u][i]]>dep[f]) u=fa[u][i];
	return u;
} // 找到u在f下面最高的点

void dfs(int u,int fa) {
	nfa[u]=fa;
	vis[u]=0;
	if(ca==0 && cb==0 && cc==0) ans=(ans+g[u])%P;
	for(int v:E[u]) {
		int t=Up(v,u);
		if(ca==0 && cb==0 && cc==0) ans=((ans-g[t]-dp[t])%P+P)%P;
	}
	for(int v:E[u]) {
		tmp[v]=Que(v,u),tmp2[v]=Up(v,u);
		dfs(v,u);
		Node t=Que(v,u);
		if(ca==0 && cb==0 && cc==0) ans=(ans+t.sum)%P; // 各种神奇的特判,对拍一年,我死了
		E[v].pb(u);
	}
}

int mi,rt,sz[N];
void FindRt(int n,int u,int f) {
	int ma=n-sz[u];
	for(int v:E[u]) if(!vis[v] && v!=f) FindRt(n,v,u),cmax(ma,sz[v]);
	if(mi>ma) mi=ma,rt=u;
}

int TL[N],TR[N],nca,ncb,ncc;// 虚树上的dfs序
void dfs2(int u,int f) {
	nca+=(col[u]==a),ncb+=(col[u]==b),ncc+=(col[u]==c);
	nfa[u]=f;
	TL[u]=++cnt,sz[u]=1,line[cnt]=u;
	for(int v:E[u]) if(!vis[v] && v!=f) {
		dfs2(v,u);
		sz[u]+=sz[v];
	}
	TR[u]=cnt;
}


void Calc(int u){
	if(nca<ca || ncb<cb || ncc<cc) return;
	rep(i,1,cnt) rep(j,0,ca) rep(k,0,cb) rep(d,0,cc) f[i][j][k][d]=0;
	ll x=all[u];
	for(int v:E[u]) {
		if(dep[v]<dep[u]) x=x*Iup[u]%P;
		else x=x*Idp[tmp2[v]]%P;
		if(vis[v]) {
			if(dep[v]<dep[u]) x=x*(tmp[u].rs+1)%P;
			else x=x*(tmp[v].ls+1)%P;
		}
	}
	f[1][col[u]==a][col[u]==b][col[u]==c]=x;
	rep(i,2,cnt) {
		int u=line[i];
		ll x=all[u];
		for(int v:E[u]) {
			if(dep[v]<dep[u]) x=x*Iup[u]%P;
			else x=x*Idp[tmp2[v]]%P;
			if(vis[v]) {
				if(dep[v]<dep[u]) x=x*(tmp[u].rs+1)%P;
				else x=x*(tmp[v].ls+1)%P;
			}
		}
		if(dep[nfa[u]]<dep[u]) {
			Node t=tmp[u];
			rep(na,0,ca) rep(nb,0,cb) rep(nc,0,cc) if(f[i-1][na][nb][nc]) { 
				f[TR[u]][na][nb][nc]=(f[TR[u]][na][nb][nc]+f[i-1][na][nb][nc]*(t.ls+1))%P;
				f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]=(f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]+f[i-1][na][nb][nc]*t.s%P*x)%P;
			}
		} else {
			Node t=tmp[nfa[u]];
			rep(na,0,ca) rep(nb,0,cb) rep(nc,0,cc) if(f[i-1][na][nb][nc]) { 
				f[TR[u]][na][nb][nc]=(f[TR[u]][na][nb][nc]+f[i-1][na][nb][nc]*(t.rs+1))%P;
				f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]=(f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]+f[i-1][na][nb][nc]*t.s%P*x)%P;
			}
		}
    } // dfs序dp
	ans=(ans+f[cnt][ca][cb][cc])%P;
}

void Divide(int u) {
	nca=ncb=ncc=0;
	cnt=0,dfs2(u,0);
	Calc(u),vis[u]=1;
	for(int v:E[u]) if(!vis[v]) {
		mi=1e9,FindRt(sz[v],v,u);
		Divide(rt);
	}
}

int main(){
	n=rd(),m=rd();
	rep(i,1,n) {
		col[i]=rd();
		V[col[i]].pb(i);
	}
	rep(i,2,n) {
		int u=rd(),v=rd();
		G[u].pb(v),G[v].pb(u);
	}
	pre_dfs(1,0),redfs(1,0);
	rep(kase,1,m) {
		a=rd(),ca=rd(),b=rd(),cb=rd(),c=rd(),cc=rd();
		cnt=0;
		if(ca>(int)V[a].size() || cb>(int)V[b].size() || cc>(int)V[c].size()) { puts("0"); continue; }
		for(int v:V[a]) line[++cnt]=v;
		for(int v:V[b]) line[++cnt]=v;
		for(int v:V[c]) line[++cnt]=v;
		Construct();
		ans=0,dfs(1,0);
		Divide(1);
		printf("%lld\n",ans);
	}
}



原文地址:https://www.cnblogs.com/chasedeath/p/12724356.html