BZOJ 1444: [Jsoi2009]有趣的游戏 [AC自动机 高斯消元]

1444: [Jsoi2009]有趣的游戏

题意:每种字母出现概率(p_i),有一些长度len的字符串,求他们出现的概率


套路DP的话,(f[i][j]) i个字符走到节点j的概率,建出转移矩阵来矩乘几十次可以认为是无穷个字符,就得到概率了

但我们发现Trie图也是图啊,直接高斯消元就好了,(f[i])表示走到节点i的期望次数

注意(f[0])需要+1

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int N=105;
const double eps=1e-8;
inline int read(){
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

int n, len, m, pos[N]; double p[N], x, y;
char s[N];
namespace ac{
	struct meow{int ch[11], fail, val;} t[N];
	int sz;
	void insert(char *s, int id) {
		int u=0;
		for(int i=1; i<=len; i++) {
			int c=s[i]-'A';
			if(!t[u].ch[c]) t[u].ch[c] = ++sz;
			u=t[u].ch[c]; 
		}
		t[u].val=1; 
		pos[id]=u;
	}

	int q[N], head, tail;
	void build() {
		head=tail=1;
		for(int i=0; i<m; i++) if(t[0].ch[i]) q[tail++] = t[0].ch[i];
		while(head!=tail) {
			int u=q[head++];
			t[u].val |= t[t[u].fail].val;
			for(int i=0; i<m; i++) {
				int &v = t[u].ch[i];
				if(!v) v = t[t[u].fail].ch[i];
				else t[v].fail = t[t[u].fail].ch[i], q[tail++]=v;
			}
		}
	}
}using ac::t; using ac::sz;

double a[N][N];
namespace eq{
	void build() { 
		a[0][sz+1] = 1;
		for(int i=0; i<=sz; i++) {  //printf("i %d
",i);
			a[i][i]=1;
			if(!t[i].val) for(int j=0; j<m; j++) 
				a[ t[i].ch[j] ][i] -= p[j];// printf("ch %d %lf  %d
",j,p[j],t[i].ch[j]);
		}
		//for(int i=0; i<=n; i++) for(int j=0; j<=n+1; j++) printf("%lf%c",a[i][j],j==n+1?'
':' ');
	}

	void gauss(int n) {
		for(int i=0; i<=n; i++) {
			int r=i;
			for(int j=i; j<=n; j++) if(abs(a[j][i])>abs(a[r][i])) r=j;
			if(r!=i) for(int j=0; j<=n+1; j++) swap(a[r][j], a[i][j]);

			for(int k=i+1; k<=n; k++) {
				double t = a[k][i]/a[i][i];
				for(int j=i; j<=n+1; j++) a[k][j] -= t*a[i][j];
			}
		}
		for(int i=n; i>=0; i--) {
			for(int j=n; j>i; j--) a[i][n+1] -= a[i][j]*a[j][n+1];
			a[i][n+1] /= a[i][i];
		}
	}
}
int main() {
	freopen("in","r",stdin);
	n=read(); len=read(); m=read(); 
	int flag=0;
	for(int i=0; i<m; i++) x=read(), y=read(), p[i]=(double)x/y, flag |= p[i]>eps;
	if(!flag) {for(int i=1; i<=n; i++) puts("0.00"); return 0;}

	for(int i=1; i<=n; i++) scanf("%s",s+1), ac::insert(s, i);
	ac::build();
	eq::build(); eq::gauss(sz);
	//for(int i=1; i<=n; i++) printf("%d ",pos[i]); puts(" pos");
	for(int i=1; i<=n; i++) printf("%.2lf
", a[pos[i]][sz+1]);
}

原文地址:https://www.cnblogs.com/candy99/p/6666722.html