【JZOJ5363】【NOIP2017提高A组模拟9.14】生命之树 Trie+启发式合并

题面


45

在比赛中,我只想到了45分的暴力。
对于一个树中点对,相当于在他们的LCA及其祖先加上这个点对的贡献。
那么这个可以用dfs序+树状数组来维护。

100

想法

我想到了可能要用trie树来维护这个字符串的公共前缀。
然后这就面临了两个很严重的问题。
1.我对于每个子树都要建一个trie,所以这是(O(n^2))的复杂度。
我想到了要合并儿子的信息,但是这个合并似乎是无法存储。
2.我还要处理xor的问题,我的想法是在trie上的每个结点上维护一个蜜汁容器。
可能这要用到xor的某些运算法则,但我并不知道如何实现。

然后正解就恰好解决了我这两个问题。

zrO lhy Orz

1.trie数可以使用启发式合并,那么时间复杂度就降为(O(nlogn))
合并的时候,可以抛弃掉子树的信息,所以空间复杂度不会超过(O(n))
2.xor我们考虑按位分治,那么我们给trie上的每个结点维护一个(cnt[i][j][0/1])
表示这个结点(i)为根的子树内,有多少个数的二进制下第(j)位为(0/1)的个数。
这个在trie合并时可以简单合并。同时在合并的时候就能利用这个(cnt)统计答案。
具体就不展开,也就是(cnt(*)(*)[0]*cnt(*)(*)[1])之类的。

Code

#include<bits/stdc++.h>
#define ll long long
#define fo(i,x,y) for(int i=x;i<=y;i++)
#define fd(i,x,y) for(int i=x;i>=y;i--)
#define ln(x,y) int(log(x)/log(y))
using namespace std;
const char* fin="1.in";
const char* fout="1.out";
const int inf=0x7fffffff;
int read(){
	int x=0;
	char ch=getchar();
	while (ch<'0' || ch>'9') ch=getchar();
	while (ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
	return x;
}
const int maxn=100007,maxm=maxn*2,maxt=600007;
int fi[maxn],la[maxm],ne[maxm],tot;
void add_line(int a,int b){
	tot++;
	ne[tot]=fi[a];
	la[tot]=b;
	fi[a]=tot;
}
void add(int a,int b){add_line(a,b);add_line(b,a);}
int n,a[maxn],rt[maxn],si[maxn],num;
ll ans[maxn];
struct node{
	int ne[26],cnt[17][2],cn[17][2];
}ac[maxt];
int b[maxn][2],hd,tl;
void dfs(int p,int _p,int de,ll &z){
	fo(i,0,25){
		int x=ac[p].ne[i],y=ac[_p].ne[i];
		if (x){
			fo(j,0,16) z+=1ll*ac[x].cnt[j][0]*(ac[_p].cnt[j][1]-ac[y].cnt[j][1])*(1<<j)*de,z+=1ll*ac[x].cnt[j][1]*(ac[_p].cnt[j][0]-ac[y].cnt[j][0])*(1<<j)*de;
			if (y) dfs(x,y,de+1,z);
		}
	}
	fo(j,0,16) z+=1ll*ac[p].cn[j][0]*ac[_p].cnt[j][1]*de*(1<<j),z+=1ll*ac[p].cn[j][1]*ac[_p].cnt[j][0]*de*(1<<j);
}
void link(int p,int _p){
	fo(i,0,16){
		ac[p].cn[i][0]+=ac[_p].cn[i][0];
		ac[p].cn[i][1]+=ac[_p].cn[i][1];
		ac[p].cnt[i][0]=ac[p].cn[i][0];
		ac[p].cnt[i][1]=ac[p].cn[i][1];
	}
	fo(i,0,25){
		int x=ac[p].ne[i],y=ac[_p].ne[i];
		if (x && y) link(x,y);
		else if (y) ac[p].ne[i]=y;
		if (ac[p].ne[i]){
			int x=ac[p].ne[i];
			fo(i,0,16){
				ac[p].cnt[i][0]+=ac[x].cnt[i][0];
				ac[p].cnt[i][1]+=ac[x].cnt[i][1];
			}
		}
	}
}
void merge(int x,int y,ll &z){
	dfs(rt[x],rt[y],0,z);
	link(rt[x],rt[y]);
	si[x]+=si[y];
}
int main(){
	freopen(fin,"r",stdin);
	freopen(fout,"w",stdout);
	scanf("%d",&n);
	fo(i,1,n) scanf("%d",&a[i]);
	fo(i,1,n){
		char ch=getchar();
		while (ch<'a' || ch>'z') ch=getchar();
		rt[i]=++num;
		int x=rt[i];
		while (ch>='a' && ch<='z'){
			fo(k,0,16) ac[x].cnt[k][a[i]>>k&1]++;
			int y=ch-'a';
			si[i]++;
			x=ac[x].ne[y]=++num;
			ch=getchar();
		}
		fo(k,0,16) ac[x].cnt[k][a[i]>>k&1]++,ac[x].cn[k][a[i]>>k&1]++;
	}
	fo(i,1,n-1) add(read(),read());
	hd=tl=0;
	b[++tl][0]=1;
	while (hd++<tl){
		int v=b[hd][0],from=b[hd][1];
		for(int k=fi[v];k;k=ne[k])
			if (la[k]!=from) b[++tl][0]=la[k],b[tl][1]=v;
	}
	fd(i,tl,1){
		int v=b[i][0],from=b[i][1];
		int mx=v;
		for(int k=fi[v];k;k=ne[k])
			if (la[k]!=from){
				ans[v]+=ans[la[k]];
				if (!mx || si[mx]<si[la[k]]) mx=la[k];
			}
		if (mx!=v) merge(mx,v,ans[v]);
		for(int k=fi[v];k;k=ne[k])
			if (la[k]!=from && la[k]!=mx){
				merge(mx,la[k],ans[v]);
			}
		rt[v]=rt[mx];
	}
	fo(i,1,n) printf("%lld
",ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/hiweibolu/p/7528884.html