[LOJ#2540][PKUWC2018]随机算法(概率DP)

场上数据很水,比较暴力的做法都可以过90分以上,下面说几个做法。

1. 暴力枚举所有最大独立集,对每个独立集分别DP。复杂度玄学,但是由于最大独立集并不多,所以可以拿90.

2. dp[S][k]表示考虑到排列的第k位,当前独立集为S的方案数,枚举第k+1位,根据是否与S相连转移到dp[S][k+1]或dp[S | a[k+1]][k+1]。$O(n^22^n)$

3. dp[S]表示排列的状态为S时的正确率,mx[S]表示排列状态为S时能得到的最大独立集大小,考虑转移,枚举排列里最后一个在独立集中的点i∈S,从S中删去所有与i相连的点得到S',若mx[S]<mx[S']+1则更新mx[S],dp[S]清零,否则累加。注意到每个排列都是等概率出现的,所以最后直接除以|S|即可。 $O(n2^n)$

方法一:

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++)
 5 #define ll long long
 6 using namespace std;
 7 
 8 const int N=1<<22,mod=998244353;
 9 ll n,m,x,y,s[25],p[25],f[N][25],cnt,mx,v[N],num[N],t[N],ans,o[N];
10 
11 int main(){
12     freopen("walk.in","r",stdin);
13     freopen("walk.out","w",stdout);
14     scanf("%lld%lld",&n,&m);
15     p[1]=1; rep(i,2,n) p[i]=p[i-1]<<1;
16     rep(i,1,m) scanf("%lld%lld",&x,&y),s[x]|=p[y],s[y]|=p[x];
17     cnt=(1<<n)-1; f[0][0]=1;
18     rep(i,0,cnt){
19         ll tmp=0; v[i]=1;
20         rep(j,1,n) if ((i&p[j])&&(s[j]&i)) v[i]=0;
21         if (v[i]){
22             rep(j,1,n) if (i&p[j]) tmp++,t[i]|=s[j];
23             num[i]=tmp; mx=max(mx,tmp);
24             tmp=0;
25             rep(j,1,n) if (t[i]&p[j]) tmp++;
26             o[i]=tmp;
27         }
28     }
29     rep(i,0,cnt) if (v[i])
30         rep(j,0,o[i]){
31             if (j!=o[i]) f[i][j+1]=(f[i][j+1]+f[i][j]*(o[i]-j))%mod;
32             rep(k,1,n) if (!(i&p[k])&&!(p[k]&t[i])) f[i|p[k]][j]=(f[i|p[k]][j]+f[i][j])%mod;
33             if (num[i]==mx && j==o[i]) ans=(ans+f[i][j])%mod;
34         }
35       printf("%lld
",ans);
36       return 0;
37 }

方法二:

 1 #include<iostream> 
 2 #include<cstdio>
 3 #include<cmath>
 4 #include<cstdlib>
 5 #include<cstring>
 6 #include<algorithm>
 7 using namespace std;
 8 int read()
 9 {
10     int x=0,f=1;char c=getchar();
11     while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
12     while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
13     return x*f;
14 }
15 #define P 998244353
16 #define N 21
17 #define t (1<<n)
18 int n,m;
19 long long ans=0;
20 bool flag[1<<(N-1)];
21 int s[1<<(N-1)],w[N],v[1<<(N-1)],cnt[1<<(N-1)],tot[1<<(N-1)],f[21][1<<(N-1)],maximum=1;
22 int main()
23 {
24     freopen("walk.in","r",stdin);
25     freopen("walk.out","w",stdout);
26     n=read(),m=read();
27     for (int i=1;i<=n;i++) w[i]=1<<(i-1),s[w[i]]=w[i];
28     for (int i=1;i<=m;i++) 
29     {
30         int x=read(),y=read();
31         s[w[x]]|=w[y],s[w[y]]|=w[x];
32     }
33     flag[0]=1;
34     for (int i=0;i<t;i++)
35     if (flag[i])
36         for (int j=1;j<=n;j++)
37         if (!(w[j]&s[i])) 
38         {
39             flag[i|w[j]]=1,s[i|w[j]]=s[i]|s[w[j]],cnt[i|w[j]]=cnt[i]+1;
40             if (cnt[i]>=maximum) maximum=cnt[i|w[j]];
41         }
42     for (int i=0;i<t;i++) 
43     {
44         s[i]=(~s[i])&(t-1);
45         register int k=s[i];
46         while (k) k^=k&-k,tot[i]++;
47         v[i]=i&-i;
48     }
49     f[0][0]=1;
50     for (register int i=0;i<n;i++)
51         for (register int j=0;j<t;j++)
52         if (f[i][j]) 
53         {
54             for (register int k=s[j];k;k^=v[k])
55             f[i+1][j|v[k]]=(f[i+1][j|v[k]]+f[i][j])%P;
56             f[i+1][j]=(1ll*f[i][j]*(n-i-tot[j])+f[i+1][j])%P;
57         }
58     for (int i=0;i<t;i++) if (cnt[i]==maximum) ans=(ans+f[n][i])%P;
59     cout<<ans;
60     fclose(stdin);fclose(stdout);
61     return 0;
62 }

方法三:

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 #define rep(i,l,r) for (int i=l; i<=r; i++)
 5 typedef long long ll;
 6 using namespace std;
 7 
 8 const int N=21,mod=998244353;
 9 int n,m,x,y,inv[N],f[N],mx[1<<N],F[1<<N];
10 
11 int main(){
12     scanf("%d%d",&n,&m);
13     rep(i,1,m) scanf("%d%d",&x,&y),x--,y--,f[x]|=1<<y,f[y]|=1<<x;
14     inv[1]=1; f[0]|=1; F[0]=1;
15     rep(i,2,n) f[i-1]|=(1<<(i-1)),inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
16     for (int i=1; i<(1<<n); i++){
17         int tot=0;
18         for (int j=0; j<n; j++) if (i&(1<<j)){
19             int s=i&(~f[j]);
20             if (mx[i]<mx[s]+1) mx[i]=mx[s]+1,F[i]=0;
21             if (mx[i]==mx[s]+1) F[i]=(F[i]+F[s])%mod;
22             tot++;
23         }
24         F[i]=1ll*F[i]*inv[tot]%mod;
25     }
26     printf("%d
",F[(1<<n)-1]);
27     return 0;
28 }
原文地址:https://www.cnblogs.com/HocRiser/p/9059172.html