【HDOJ5955】Guessing the Dice Roll(概率DP,AC自动机,高斯消元)

题意:

有n个人,每个人有一个长为L的由1~6组成的数串,现在扔一个骰子,依次记录扔出的数字,如果当前扔出的最后L个数字与某个人的数串匹配,那么这个人就算获胜,现在问每个人获胜的概率是多少。

n,l<=10

思路:对于无限型的概率

首先显然有一个暴力做法是对于n个串建出AC自动机和转移矩阵后跑若干次矩乘快速幂DP使得答案趋于稳定后可以将结果看做答案

正解是高斯消元

每个点对于它的后继有1/6的概率跑到,计算贡献后累加

边界条件:题目给出的n个串没有后继

游戏开始时必定会转移到根节点,等价于有一个虚的节点(不需要建立方程)对根有100%的1的贡献

剩下的就是高斯消元模板了

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<string>
  4 #include<cmath>
  5 #include<iostream>
  6 #include<algorithm>
  7 #include<map>
  8 #include<set>
  9 #include<queue>
 10 #include<vector>
 11 #include<bits/stdc++.h>
 12 using namespace std;
 13 typedef long long ll;
 14 typedef unsigned int uint;
 15 typedef unsigned long long ull;
 16 typedef pair<int,int> PII;
 17 typedef vector<int> VI;
 18 #define fi first
 19 #define se second 
 20 #define MP make_pair
 21 #define N  11000
 22 #define M  210
 23 #define MOD 1000000007
 24 #define eps 1e-8 
 25 #define pi acos(-1)
 26 
 27 double a[M][M];
 28 int nxt[M][7],fa[M],c[M],d[N],q[N],b[N],flag[N],n,l,cnt;
 29 
 30 int read()
 31 { 
 32    int v=0,f=1;
 33    char c=getchar();
 34    while(c<48||57<c) {if(c=='-') f=-1; c=getchar();}
 35    while(48<=c&&c<=57) v=(v<<3)+v+v+c-48,c=getchar();
 36    return v*f;
 37 }
 38 
 39 void build(int k)
 40 {
 41     int u=1;
 42     for(int i=1;i<=l;i++)
 43     {
 44         if(!nxt[u][b[i]]) nxt[u][b[i]]=++cnt;
 45         u=nxt[u][b[i]];
 46     }
 47     c[k]=u;
 48     d[u]=k;
 49 }
 50             
 51 void acauto()
 52 {
 53     int t=0; int w=1; q[1]=1;
 54     while(t<w)
 55     {
 56         int u=q[++t];
 57     //    printf("%d
",u); 
 58         for(int i=1;i<=6;i++)
 59         {
 60              if(nxt[u][i])
 61               {
 62                  int son=nxt[u][i];        
 63                  int p=fa[u];
 64                  if(u==1) fa[son]=1;
 65                   else fa[son]=nxt[p][i];
 66                  q[++w]=son;
 67              }
 68               else
 69               {
 70                   int p=fa[u];
 71                 if(u==1) nxt[u][i]=1;
 72                  else nxt[u][i]=nxt[p][i];
 73                }
 74         }
 75     }
 76 }
 77 
 78 void init()
 79 {
 80 /*    for(int i=1;i<=cnt;i++)
 81     {
 82         fa[i]=flag[i]=d[i]=0;
 83         for(int j=1;i<=6;j++) nxt[i][j]=0;
 84     }
 85     memset(q,0,sizeof(q));
 86     for(int i=1;i<=cnt;i++)
 87      for(int j=0;j<=cnt;j++) a[i][j]=0;
 88     for(int i=1;i<=n;i++) c[i]=0;*/
 89     memset(a,0,sizeof(a));
 90     memset(nxt,0,sizeof(nxt));
 91     memset(fa,0,sizeof(fa));
 92     memset(c,0,sizeof(c));
 93     memset(d,0,sizeof(d));
 94     //memset(q,0,sizeof(q));
 95     memset(b,0,sizeof(b));
 96     cnt=1;
 97 }
 98 
 99 void test()
100 {
101     for(int i=1;i<=cnt;i++)
102     {
103         for(int j=1;j<=cnt+1;j++) printf("%.2lf ",a[i][j]);
104         printf("
");
105     }
106     printf("
");
107 }
108 
109 int main()
110 {
111     freopen("hdoj5955.in","r",stdin);
112     freopen("hdoj5955.out","w",stdout);
113     int cas;
114     scanf("%d",&cas);
115     while(cas--)
116     {
117         init();
118         scanf("%d%d",&n,&l);
119         for(int i=1;i<=n;i++)
120         {
121             for(int j=1;j<=l;j++) scanf("%d",&b[j]);
122             build(i);
123         }
124         
125         acauto();
126         
127         for(int i=1;i<=cnt;i++)
128         {
129             a[i][cnt+1]=0; 
130             a[i][i]=-1.0;
131             if(d[i]) continue;
132             for(int j=1;j<=6;j++)
133             {
134                 int v=nxt[i][j];
135                 a[v][i]+=1.0/6.0;
136             //    printf("%d %d
",i,v);
137             }
138         }
139         
140         a[1][cnt+1]=-1.0;
141         
142     //    test();
143 
144         for(int i=1;i<=cnt;i++)
145         {
146             int K=i;
147             for(int j=i+1;j<=cnt;j++)
148              if(fabs(a[j][i])>fabs(a[K][i])) K=j;
149              
150             if(K!=i)
151              for(int j=1;j<=cnt+1;j++) swap(a[i][j],a[K][j]);
152             
153             for(int j=i+1;j<=cnt;j++)
154             {
155                 double f=a[j][i]/a[i][i];
156                 for(int k=i;k<=cnt+1;k++) a[j][k]-=f*a[i][k];
157             } 
158         
159             
160         }
161         
162         for(int i=cnt;i>=1;i--)
163         {
164              for(int j=i+1;j<=cnt;j++) 
165               a[i][cnt+1]-=a[i][j]*a[j][cnt+1];
166              a[i][cnt+1]=a[i][cnt+1]/a[i][i];
167         }
168         
169     //    test();
170         for(int i=1;i<=n;i++) 
171          if(i<n) printf("%.6lf ",a[c[i]][cnt+1]);
172           else printf("%.6lf
",a[c[i]][cnt+1]);
173         
174     }
175 
176     return 0;
177 }
178      
原文地址:https://www.cnblogs.com/myx12345/p/9743856.html