bzoj 4911: [Sdoi2017]切树游戏

题目描述

Solution

考虑暴力DP:设 (f[x][i]) 表示 (x) 子树内, (x) 作为深度最小的点的连通块的数量
(f[x][i]=f[x][j]*f[u][k]\,j igoplus k=i)
这个过程可以用 (FWT) 优化

由于有修改,用链分治动态维护这个DP
按树链剖分的方法,把树分成若干条重链
每一条重链看作一个序列 (P_L,...P_R),按照深度从 (L)(R) 递减的顺序排列,线段树维护

分别记录以下东西:
(sum[x][i]) 表示线段树中 (x) 所代表的区间的异或和为 (i) 的连通块的答案和
(li[x][i]) 表示 线段树中 (x) 所代表的区间中包含左端点的异或和为 (i) 的连通块的答案和
(ri[x][i]) 表示 线段树中 (x) 所代表的区间中包含右端点的异或和为 (i) 的连通块的答案和
(siz[x][i]) 表示 线段树中 (x) 所代表的区间 ([L,R]) 这个完整的异或和为 (i) 的连通块的答案(也就是每一个位置权值的乘积)

同一条链的转移十分简单,考虑链与链之间的转移:
我们把这一条链直接当作 链顶的父亲 的权值就行了
更新的时候在链上暴力跳就行了
复杂度是 (log^2)

考虑这个转移是需要 (FWT) 优化的,复杂度又多了个 (log)

有一种方法优化:
我们 (FWT) 时,是先 (FWT(a,1)),再做点值多项式乘法,再转回来的过程
我们可以一开始就转好点值多项式,然后运算过程全程用点值多项式的值来代入,中间的运算过程就可以变成普通的点值乘法了
在询问的时候再 (FWT) 回来就行了

这样复杂度就是 (O(n*m*log^2)) 的了

另外注意:
(0) 没有逆元,由于会除以 (0),所以要定义一种新运算维护 (0) 的个数,重载一下乘除号就行了

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N=30005,M=130,mod=10007;
int n,m,Q,a[N],sz[N],son[N],dep[N],head[N],nxt[N*2],to[N*2],num=0;
int top[N],fa[N],inv[N],E[M][M],lis[N],tt=0,ans[M],re[M];
vector<int>v[N];
inline void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
inline void dfs(int x){
	sz[x]=1;
	for(int i=head[x];i;i=nxt[i]){
		int u=to[i];if(sz[u])continue;
		dep[u]=dep[x]+1;fa[u]=x;dfs(u);
		sz[x]+=sz[u];if(sz[u]>sz[son[x]])son[x]=u;
	}
}
inline void dfs2(int x,int tp){
	top[x]=tp;
	if(son[x])dfs2(son[x],tp);
	for(int i=head[x];i;i=nxt[i])
		if(to[i]!=fa[x] && to[i]!=son[x])dfs2(to[i],to[i]);
	v[tp].pb(x);
}
inline void fwt(int *A,int o){
	for(int i=1;i<m;i<<=1)
		for(int j=0;j<m;j+=i<<1)
			for(int k=0;k<i;k++){
				int x=A[j+k],y=A[j+k+i];
				if(!o)A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
				else A[j+k]=(x+y)*inv[2]%mod,A[j+k+i]=(x-y+mod)*inv[2]%mod;
			}
}
struct data{
	int a,b;
	inline void biu(int x){x%=mod;if(x)a=x,b=0;else a=1,b=1;}
	inline int val(){return b?0:a;}
	inline void operator *=(const int x){
		if(!x)b++;
		else a=a*x%mod;
	}
	inline void operator /=(const int x){
		if(!x)b--;
		else a=a*inv[x]%mod;
	}
}f[N][M];
void priwork(){
	inv[1]=1;
	for(int i=2;i<mod;i++)inv[i]=(mod-(mod/i)*inv[mod%i]%mod)%mod;
	int len;for(len=1;len<m;len<<=1);m=len;
	for(int i=0;i<m;i++)E[i][i]=1,fwt(E[i],0);    //预处理出单位矩阵 E
       //因为我们是先把 f[i][a[i]]=1 赋为 1 再转点值表达式的,我们预处理出E[i]表示把 i 赋成1时的单位多项式
	for(int i=1;i<=n;i++)
		for(int j=0;j<m;j++)f[i][j].biu(E[a[i]][j]);
}
inline bool comp(int i,int j){return dep[i]>dep[j];}
int ls[N*4],rs[N*4],rt[N],li[N*4][M],ri[N*4][M];
int ft[N*4],sum[N*4][M],siz[N*4][M],id[N];
inline void upd(int o){
	for(int i=0;i<m;i++){
		sum[o][i]=(sum[ls[o]][i]+sum[rs[o]][i]+ri[ls[o]][i]*li[rs[o]][i])%mod;
		li[o][i]=(li[ls[o]][i]+li[rs[o]][i]*siz[ls[o]][i])%mod;
		ri[o][i]=(ri[rs[o]][i]+ri[ls[o]][i]*siz[rs[o]][i])%mod;
		siz[o][i]=siz[ls[o]][i]*siz[rs[o]][i]%mod;
	}
}
inline void build(int &x,int l,int r,int t){
	x=++tt;
	if(l==r){
		id[v[t][l]]=x;
		for(int i=0;i<m;i++)
			li[x][i]=ri[x][i]=sum[x][i]=siz[x][i]=f[v[t][l]][i].val();
		return ;
	}
	int mid=(l+r)>>1;
	build(ls[x],l,mid,t);build(rs[x],mid+1,r,t);
	if(ls[x])ft[ls[x]]=x;if(rs[x])ft[rs[x]]=x;
	upd(x);
}
inline void solve(int x){
	int t=top[x];
	if(fa[t])for(int i=0;i<m;i++)f[fa[t]][i]/=(ri[rt[t]][i]+E[0][i])%mod;
	for(int i=0;i<m;i++)ans[i]=(ans[i]-sum[rt[t]][i]+mod)%mod;
	int p=id[x];
	for(int i=0;i<m;i++)
		li[p][i]=ri[p][i]=sum[p][i]=siz[p][i]=f[x][i].val();
	for(p=ft[p];p;p=ft[p])upd(p);
	if(fa[t])for(int i=0;i<m;i++)f[fa[t]][i]*=(ri[rt[t]][i]+E[0][i])%mod;
	for(int i=0;i<m;i++)ans[i]=(ans[i]+sum[rt[t]][i])%mod;
}
int main(){
  freopen("pp.in","r",stdin);
  freopen("pp.out","w",stdout);
  int x,y;char S[8];
  scanf("%d%d",&n,&m);
  for(int i=1;i<=n;i++)scanf("%d",&a[i]);
  for(int i=1;i<n;i++)scanf("%d%d",&x,&y),link(x,y),link(y,x);
  dep[1]=1;dfs(1);dfs2(1,1);
  priwork();
  int cnt=0;
  for(int i=1;i<=n;i++)if(top[i]==i)lis[++cnt]=i;
  sort(lis+1,lis+cnt+1,comp);
  for(int i=1;i<=cnt;i++){
	  x=lis[i];
	  build(rt[x],0,v[x].size()-1,x);
	  if(fa[x])
		  for(int j=0;j<m;j++)f[fa[x]][j]*=(ri[rt[x]][j]+E[0][j])%mod;
	  for(int j=0;j<m;j++)ans[j]=(ans[j]+sum[rt[x]][j])%mod;
  }
  cin>>Q;
  while(Q--){
	  scanf("%s%d",S,&x);
	  if(S[0]=='Q'){
		  for(int i=0;i<m;i++)re[i]=ans[i];
		  fwt(re,1);
		  printf("%d
",re[x]);
	  }
	  else{
		  scanf("%d",&y);
		  for(int i=0;i<m;i++)f[x][i]/=E[a[x]][i];
		  a[x]=y;
		  for(int i=0;i<m;i++)f[x][i]*=E[a[x]][i];
		  for(;x;x=fa[top[x]])solve(x);
	  }
  }
  return 0;
}

原文地址:https://www.cnblogs.com/Yuzao/p/8576879.html