51Nod1868 彩色树 虚树

原文链接https://www.cnblogs.com/zhouzhendong/p/51Nod1868.html

题目传送门 - 51Nod1868

题意

  给定一颗 $n$个点的树,每个点一个 $[1,n]$ 的颜色。设 $g(x,y)$ 表示 $x$ 到 $y$ 的树上路径上有几种颜色。

  对于一个长度为 $n$ 的排列 $P[1cdots n]$ ,定义 $f(P)=sum_{i=1}^{n-1}g(P_i,P_{i+1})$ 。

  现在求对于 $n!$ 个排列,他们的 $f(P)$ 之和 对 $10^9+7$ 取模后的值。

题解

  首先我们考虑每一个 $g(x,y)$ 对于答案的贡献次数。

  考虑捆绑法,把 $x$ 和 $y$ 看作一个整体,显然,它对答案的贡献次数为 $(n-1)!$ 。

  于是答案就是

$$2 imes (n-1)!sum_{x=1}^{n}sum_{y=x+1}^{n} g(x,y)$$

  前面的 $2 imes (n-1)!$ 很好办,现在主要要求后面的那个。

  我们考虑对于每一个颜色分别处理。我们需要求出每一个颜色对答案的贡献。

  记 $f(c,x,y)$ 表示路径 $x$~$y$ 上,如果有颜色 $c$ ,那么值为 $1$ ,否则为 $0$ 。则后面一半变成了:

$$sum_{c=1}^{n}sum_{x=1}^{n}sum_{y=x+1}^{n} f(c,x,y)$$

  确定一种颜色之后,后面的显然非常好求,直接一个树形dp 就差不多了。但是这样的时间复杂度是炸掉的。于是我们需要一个数据结构来优化——虚树。

  建出虚树,然后我们注意一下细节,统计一下就可以了。

  这里推荐一个写的比较详细的虚树学习笔记:https://www.k-xzy.xyz/archives/4476

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=200005,mod=1e9+7;
int read(){
	int x=0;
	char ch=getchar();
	while (!isdigit(ch))
		ch=getchar();
	while (isdigit(ch))
		x=(x<<1)+(x<<3)+ch-48,ch=getchar();
	return x;
}
struct Gragh{
	static const int M=N*2;
	int cnt,y[M],nxt[M],fst[N];
	void clear(){
		cnt=0;
		memset(fst,0,sizeof fst);
	}
	void add(int a,int b){
		y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt;
	}
}g,t;
int n,c[N],Fac[N],Time=0,now_color,ans=0;
int dfn[N],depth[N],size[N],fa[N][18],sqrsum[N];
int dirson[N],tot[N],st[N],top;
vector <int> id[N];
LL calc(int x){
	return 1LL*x*(x-1)/2;
}
void dfs(int x,int pre,int d){
	dfn[x]=++Time,depth[x]=d,size[x]=1,fa[x][0]=pre,sqrsum[x]=0;
	for (int i=1;i<18;i++)
		fa[x][i]=fa[fa[x][i-1]][i-1];
	for (int i=g.fst[x];i;i=g.nxt[i])
		if (g.y[i]!=pre){
			int y=g.y[i];
			dfs(y,x,d+1);
			size[x]+=size[y];
			sqrsum[x]=(calc(size[y])+sqrsum[x])%mod;
		}
}
int LCA(int x,int y){
	if (depth[x]<depth[y])
		swap(x,y);
	for (int i=17;i>=0;i--)
		if (depth[x]-(1<<i)>=depth[y])
			x=fa[x][i];
	if (x==y)
		return x;
	for (int i=17;i>=0;i--)
		if (fa[x][i]!=fa[y][i])
			x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
bool cmp(int a,int b){
	return dfn[a]<dfn[b];
}
void solve(int x){
	int dx=dirson[x],sonsqr=tot[x]=0;
	for (int k=t.fst[x];k;k=t.nxt[k]){
		int y=t.y[k],&dy=dirson[y]=y;
		for (int i=17;i>=0;i--)
			if (depth[dy]-(1<<i)>depth[x])
				dy=fa[dy][i];
		solve(y);
		tot[x]+=tot[y];
		sonsqr=(calc(tot[y])+sonsqr)%mod;
	}
	if (c[x]==now_color){
		tot[x]=size[x];
		int xsqr=(calc(size[dx]-size[x])+sqrsum[x])%mod;
		ans=(calc(size[dx])-xsqr+ans)%mod;
	}
	else {
		ans=(calc(tot[x])-sonsqr+ans)%mod;
		for (int i=t.fst[x];i;i=t.nxt[i]){
			int y=t.y[i],v=size[dx]-tot[x]+tot[y]-size[dirson[y]];
			ans=(1LL*tot[y]*v+ans)%mod;
		}
	}
}
int main(){
	n=read();
	for (int i=Fac[0]=1;i<=n;i++)
		c[i]=read(),Fac[i]=1LL*Fac[i-1]*i%mod;
	g.clear();
	for (int i=1;i<n;i++){
		int a=read(),b=read();
		g.add(a,b);
		g.add(b,a);
	}
	dfs(1,0,0);
	for (int i=1;i<=n;i++)
		id[i].clear();
	for (int i=1;i<=n;i++)
		id[c[i]].push_back(i);
	t.clear();
	for (int k=1;k<=n;k++){
		if (id[k].size()<1)
			continue;
		sort(id[k].begin(),id[k].end(),cmp);
		st[top=1]=1,t.fst[1]=0;
		for (vector <int> :: iterator i=id[k].begin();i!=id[k].end();i++){
			int x=*i;
			if (x==1)
				continue;
			int lca=LCA(x,st[top]);
			if (lca!=st[top]){
				while (depth[st[top-1]]>depth[lca])
					t.add(st[top-1],st[top]),top--;
				if (st[top-1]!=lca)
					t.fst[lca]=0,t.add(lca,st[top]),st[top]=lca;
				else
					t.add(lca,st[top--]);
			}
			t.fst[x]=0,st[++top]=x;
		}
		for (int i=1;i<top;i++)
			t.add(st[i],st[i+1]);
		now_color=k,dirson[1]=1;
		solve(1);
	}
	printf("%d
",2LL*(ans+mod)%mod*Fac[n-1]%mod);
	return 0;
}

  

原文地址:https://www.cnblogs.com/zhouzhendong/p/51Nod1868.html