bzoj 5287: [Hnoi2018]毒瘤

Description

Solution

(dfs) 出一棵生成树之后,多出来的边就都是反祖边了
把反祖边两个端点都拿出来,就会得到最多 (k=2*(m-n+1)) 个关键点
除了关键点以外的点转移都是一样的,我们可以预处理出来

关键点数量不多,我们 (2^k) 枚举状态,然后像树形 (DP) 一样转移就行了
转移需要构一棵虚树,对于虚树上的一条边,对应在原树上的一条链转移也是一样的
如果知道了虚树上 (x)(DP) 值,(f[x][0],f[x][1]),那么就可以推出虚树上的父亲的值 (f[fa[x]][0],f[fa[x]][1])
大致可以表示成这样的形式:(f[fa[x]][0]=k0*f[x][0]+k1*f[x][1]),(f[fa[x]][1]) 同理

对于转移系数和虚树上某些节点的 (DP) 初值都可以 (O(n*k)) 的预处理出来
对于一条边 ((x,y)),只有三种状态:存在 (x),存在 (y),都不存在,所以状态数实际上只有 (3^{frac{k}{2}})
复杂度是 (O(n*k+3^{frac{k}{2}}*k))

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10,mod=998244353;
int n,m,head[N],nxt[N*4],to[N*4],num=1,st[N],top=0,sq[N],fa[N][20];
int ST[N*2],TOP=0,dfn[N],DFN=0,tp=0,dep[N],q[N],r=0,Head[N],id[N];
inline bool comp(int i,int j){return dfn[i]<dfn[j];}
inline void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
inline void Link(int x,int y){nxt[++num]=Head[x];to[num]=y;Head[x]=num;}
inline int LCA(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	int deep=dep[x]-dep[y];
	for(int i=19;i>=0;i--)if(deep>>i&1)x=fa[x][i];
	if(x==y)return x;
	for(int i=19;i>=0;i--)
		if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
bool vis[N],et[N*4];int imp[N],lim,ans=0,f[N][2],lis[N];
struct data{
	int k0,k1;
	data(){}
	data(int _k0,int _k1){k0=_k0;k1=_k1;}
	inline data operator +(data &t){return data((k0+t.k0)%mod,(k1+t.k1)%mod);}
	inline data operator *(int t){
		return data(1ll*k0*t%mod,1ll*k1*t%mod);}
	inline int F(int x,int y){return (1ll*x*k0+1ll*y*k1)%mod;}
}k[N][2];
inline void build(int x,int last){
	vis[x]=1;dfn[x]=++DFN;
	for(int i=head[x];i;i=nxt[i]){
		if(i==last)continue;
		int u=to[i];
		if(!vis[u])fa[u][0]=x,dep[u]=dep[x]+1,build(u,i^1);
		else if(dep[u]<dep[x])st[++top]=x,sq[top]=u,et[i]=et[i^1]=1;
	}
}
int dp[N][2];bool d[N];
inline void calc(int x,int la){
	dp[x][0]=dp[x][1]=1;
	for(int i=head[x];i;i=nxt[i]){
		int u=to[i];
		if(u==la || u==fa[x][0] || d[u])continue;
		calc(u,la);
		dp[x][0]=1ll*dp[x][0]*(dp[u][0]+dp[u][1])%mod;
		dp[x][1]=1ll*dp[x][1]*dp[u][0]%mod;
	}
}
inline void getit(int S,int T){
	int x=S;
	k[x][0]=data(1,0);k[x][1]=data(0,1);
   while(fa[x][0]!=T){
		calc(fa[x][0],x);
		data t=k[S][0];d[fa[x][0]]=1;
		k[S][0]=(k[S][0]+k[S][1])*dp[fa[x][0]][0];
		k[S][1]=t*dp[fa[x][0]][1];
		x=fa[x][0];
	}
}
inline void DFS(int x){
	for(int i=Head[x];i;i=nxt[i]){
		DFS(to[i]);
		getit(to[i],x);
	}
	dp[x][0]=dp[x][1]=1;
	for(int i=head[x];i;i=nxt[i]){
		int u=to[i];
		if(u==fa[x][0] || d[u] || et[i])continue;
		calc(u,x);
		dp[x][0]=1ll*dp[x][0]*(dp[u][0]+dp[u][1])%mod;
		dp[x][1]=1ll*dp[x][1]*dp[u][0]%mod;
	}
}
inline void dfs(int x){
	f[x][0]=dp[x][0];f[x][1]=dp[x][1];
	for(int i=Head[x];i;i=nxt[i]){
		int u=to[i],f0,f1;
		dfs(u);
		f0=k[u][0].F(f[u][0],f[u][1]);
		f1=k[u][1].F(f[u][0],f[u][1]);
		f[x][0]=1ll*f[x][0]*(f0+f1)%mod;
		f[x][1]=1ll*f[x][1]*f0%mod;
	}
	if(imp[x]!=-1)f[x][imp[x]^1]=0;
}
int main(){
  freopen("duliu.in","r",stdin);
  freopen("duliu.out","w",stdout);
  int x,y;
  scanf("%d%d",&n,&m);
  for(int i=1;i<=m;i++){
	  scanf("%d%d",&x,&y);
	  link(x,y);link(y,x);
  }
  
  dep[1]=1;build(1,-1);
  for(int j=1;j<20;j++)
	  for(int i=1;i<=n;i++)fa[i][j]=fa[fa[i][j-1]][j-1];
  for(int i=1;i<=top;i++)ST[++TOP]=st[i],ST[++TOP]=sq[i];
  sort(ST+1,ST+TOP+1,comp);
  tp=unique(ST+1,ST+TOP+1)-ST-1;
  int cnt=0;
  for(int i=1;i<=tp;i++)lis[++cnt]=ST[i];
  lis[++cnt]=1;
  sort(lis+1,lis+cnt+1,comp);
  cnt=unique(lis+1,lis+cnt+1)-lis-1;
  for(int i=1;i<=tp;i++)id[ST[i]]=i-1;

  q[++r]=lis[1];
  for(int i=2;i<=cnt;i++){
	  x=lis[i];
	  int lca=LCA(x,lis[i-1]);d[lca]=1;
	  while(r && dfn[q[r]]>dfn[lca]){
		  if(dfn[q[r-1]]>dfn[lca])Link(q[r-1],q[r]);
		  else {
			  Link(lca,q[r]);r--;
			  if(q[r]!=lca)q[++r]=lca;break;
		  }
		  r--;
	  }
	  q[++r]=lis[i];
  }
  while(r>1)Link(q[r-1],q[r]),r--;

  for(int i=1;i<=cnt;i++)d[lis[i]]=1;
  DFS(1);lim=(1<<tp)-1;
  memset(imp,-1,sizeof(imp));
  for(int i=0;i<=lim;i++){
	  bool flag=1;
	  for(int j=1;j<=top;j++)
		  if((i>>id[st[j]]&1) && (i>>id[sq[j]]&1)){flag=0;break;}
	  if(!flag)continue;
	  for(int j=1;j<=tp;j++)imp[ST[j]]=i>>(j-1)&1;
	  dfs(1);
	  ans=((ans+f[1][0])%mod+f[1][1])%mod;
  }
  cout<<ans<<endl;
  return 0;
}

原文地址:https://www.cnblogs.com/Yuzao/p/8883096.html