树上游戏

IV.树上游戏

考虑淀粉质。

对于一棵分治树,我们考虑对树中所有LCA为根节点的路径计算贡献。

我们对于根节点一棵子树一棵子树地处理。设\(cnt_i\)表示子树外有多少条以根节点为一个端点的路径上有颜色\(i\)。则我们当前子树中的一个点的贡献可以分作两部分:子树外的部分(即\(\sum cnt_i\))以及子树内的部分(即从当前节点到子树根的路径上点的贡献)。

如果一个点是该路径上第一个本颜色的点,则会有\(size-cnt_{col_i}\)的贡献,其中\(size\)是除该子树外其他点的数量,而\(col_i\)是当前点的颜色。并且,这个贡献还会对子树中每一个点都有贡献,因此要作为一个标记传递下去(类似线段树的标记永久化)。至于如何判断是否是第一个出现的点,这需要开数组标记该颜色是否出现过,如果没有出现过就增加tag即可。

淀粉质过程中所有步骤都必须保证\(O(n)\),包括但不限于求\(cnt\),求\(\sum cnt\)以及求\(size\)等。具体操作可以通过先求出整棵树的\(cnt\)然后挖去当前子树的\(cnt\)得到。

代码(吸氧可过,否则TLE80):

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,col[100100],sz[100100],msz[100100],SZ,ROOT;
ll cnt[100100],SC,SIZE,res[100100];
bool occ[100100];
vector<int>v[100100];
bool vis[100100];
void getroot(int x,int fa){
	sz[x]=1,msz[x]=0;
	for(auto y:v[x])if(!vis[y]&&y!=fa)getroot(y,x),sz[x]+=sz[y],msz[x]=max(msz[x],sz[y]);
	msz[x]=max(msz[x],SZ-sz[x]);
	if(msz[x]<msz[ROOT])ROOT=x;
}
void getroute(int x,int fa,int tp){//tp controls whether to add colored routes or to subtract
	SIZE+=tp,sz[x]=1;
	bool pref=occ[col[x]];
	occ[col[x]]=true;
	for(auto y:v[x])if(!vis[y]&&y!=fa)getroute(y,x,tp),sz[x]+=sz[y];
	if(!pref)cnt[col[x]]+=sz[x]*tp,SC+=sz[x]*tp;
	occ[col[x]]=pref;
}
void getsubtree(int x,int fa,ll delta){
	bool pref=occ[col[x]];
	if(!occ[col[x]])delta+=SIZE-cnt[col[x]],occ[col[x]]=true;
	res[x]+=delta;
	for(auto y:v[x])if(!vis[y]&&y!=fa)getsubtree(y,x,delta);
	occ[col[x]]=pref;
}
void getans(int x){
	getroute(x,0,1),res[x]+=SC;
//	for(int i=1;i<=n;i++)printf("%d ",sz[i]);puts("");
//	for(int i=1;i<=m;i++)printf("%d ",cnt[i]);puts("");
	occ[col[x]]=true;
	for(auto y:v[x]){
		if(vis[y])continue;
		getroute(y,x,-1);
		cnt[col[x]]-=sz[y],SC-=sz[y]; 
//		printf("SON:%d:",y);for(int i=1;i<=m;i++)printf("%d ",cnt[i]);puts("");
//		printf("%d %d\n",SC,SIZE);
		getsubtree(y,x,SC);
		cnt[col[x]]+=sz[y],SC+=sz[y]; 
		getroute(y,x,1);
//		for(int i=1;i<=n;i++)printf("%d ",res[i]);puts("");
	}
	occ[col[x]]=false;
	getroute(x,0,-1);
}
void solve(int x){
	getans(x),vis[x]=true;
	for(auto y:v[x])if(!vis[y])ROOT=0,SZ=sz[y],getroot(y,x),solve(ROOT);
}
void read(int &x){
	x=0;
	char c=getchar();
	while(c>'9'||c<'0')c=getchar();
	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
}
int main(){
	read(n);
	for(int i=1;i<=n;i++)read(col[i]);
	for(int i=1,x,y;i<n;i++)read(x),read(y),v[x].push_back(y),v[y].push_back(x);
	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
	for(int i=1;i<=n;i++)printf("%lld\n",res[i]);
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14605780.html