[WC2018]州区划分(FWT)

题目描述

题解

这道题的思路感觉很妙。

题目中有一个很奇怪的不合法条件,貌似和后面做题没有什么关系,所以我们先得搞掉它。

也就是判断一个点集是否合法,也就是判断这个点集是否存在欧拉回路。

如果存在欧拉回路每个点的度都得是偶数而且图联通,这个条件扫描一遍在上一个并查集就可以判掉了。

然后开始统计答案。

n很小,可以考虑状压dp,我们设dp[s]为已经划分好的州区点集和为s它的所有方案的答案的和。

转移可以考虑枚举子集。

dp[s]=∑dp[s']*(sum[s^s']/sum[s])p

然后我们发现sum[s]p是和转移的枚举无关的,所以我们可以稍稍变换一下变成
dp[s]*sum[s]p=∑dp[s']*sum[s^s']p

这样的复杂度是3n的,我们要考虑优化。

我们换一种枚举方式

dp[S]*sum[s]p=∑dp[s]*sum[s']p  (s&s'==0)&&(s|s'==S)

如果没有前面那个条件,那么它就是一个形式有点奇怪的或卷积,在加上那个烦人的条件,就有点让人怀疑人生了。。

然后关键的思路来了,考虑到这道题时限较长,复杂度可以在卷积的基础上加一个n,所以我们把dp状态多开一维。

我们设dp[i][S]表示在S点集中有i个点的答案。

这样的状态设计虽然带来了大量冗余状态,却方便了我们的转移。

dp[i][S]=∑dp[j][s]*sum[i-j][s']p  (s|s'==S)

这样我们成功的用一个n的时间复杂度把这个东西变成了一个正常的或卷积,直接上FWT就可以了。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 22
#define R register
using namespace std;
typedef long long ll;
int n,m,p,lowb[1<<21],cou[1<<21],size,du[N],fa[N];
ll f[N][1<<21],g[N][1<<21],w[N],sum[1<<21],ny[1<<21];
bool a[N][N],b[N];
const int mod=998244353;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x; 
}
inline ll power(ll x,ll y){
    ll ans=1;
    while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;}
    return ans;
}
inline ll ni(ll x){return power(x,mod-2);}
inline void MOD(ll &x){x=(x+mod)%mod;}
inline void FWT(ll *a,int tag){
    for(R int i=1;i<size;i<<=1)
        for(R int j=0;j<size;j+=(i<<1))
          for(R int k=0;k<i;++k)MOD(a[i+j+k]+=tag*a[j+k]);
}
int find(int x){return fa[x]=fa[x]==x?x:find(fa[x]);}
int main(){
    n=rd();m=rd();p=rd();int u,v;size=(1<<n);
    for(R int i=1;i<=m;++i){u=rd();v=rd();a[u][v]=a[v][u]=1;}
    for(R int i=1;i<=n;++i)w[i]=rd();
    for(R int i=1;i<size;i<<=1)lowb[i]=lowb[i>>1]+1;
    for(R int s=0;s<size;++s){
       cou[s]=cou[s>>1]+(s&1);bool haha=0;int num=cou[s];
       for(R int i=1;i<=n;++i){if(s&(1<<i-1))sum[s]+=w[i],b[i]=1;else b[i]=0;du[i]=0;fa[i]=i;}
       for(R int i=1;i<=n;++i)if(b[i])for(int j=i+1;j<=n;++j)if(a[i][j]&&b[j]){
         du[i]++,du[j]++;
         int xx=find(i),yy=find(j);if(xx!=yy)fa[xx]=yy,num-=1;
       }
       for(R int i=1;i<=n;++i)if(du[i]&1){haha=1;break;}
       if(num!=1)haha=1;
       if(!haha)continue;
       g[cou[s]][s]=power(sum[s],p);
    }
    for(R int s=0;s<size;++s)ny[s]=ni(power(sum[s],p));
    for(R int i=1;i<=n;++i)FWT(g[i],1);
    f[0][0]=1;FWT(f[0],1);
    for(R int i=1;i<=n;++i){
        for(R int j=0;j<i;++j){
            for(R int k=0;k<size;++k)(f[i][k]+=f[j][k]*g[i-j][k])%=mod;
        }
        FWT(f[i],-1);
        for(R int j=0;j<size;++j)f[i][j]=f[i][j]*ny[j]%mod; 
        if(i<n)FWT(f[i],1);
    }
    printf("%lld",f[n][size-1]);
    return 0;
} 
原文地址:https://www.cnblogs.com/ZH-comld/p/10241368.html