bzoj2553

$AC自动机+矩阵快速幂$

$多串统计问题...上AC自动机$

$看见len很大,那么我们就得用矩阵乘法$

$关于dp,题目中希望尽量多地匹配,AC自动机的匹配过程满足了这个条件,所以转移按AC自动机匹配的顺序转移就行了$

$如果到了一个终止节点,那么我们要统计,这里额外增加一个节点专门统计答案$

#include<bits/stdc++.h>
using namespace std;
const int N = 105;
int n, len, m;
char s[N];
struct node {
    int fail, p;
    int ch[26];
} t[N];
int cnt, root;
struct matrix {
    long double a[N][N];
    matrix() { for(int i = 0; i <= cnt; ++i) for(int j = 0; j <= cnt; ++j) a[i][j] = 0; }
    void set() {
        for(int i = 0; i <= cnt; ++i) {
            a[i][i] = 1;
        }
    }
    matrix friend operator * (const matrix &a, const matrix &b) {
        matrix t;
        for(int i = 0; i <= cnt; ++i) {
            for(int j = 0; j <= cnt; ++j) {
                for(int k = 0; k <= cnt; ++k) {
                    t.a[i][j] += a.a[i][k] * b.a[k][j];
                }
            }
        }
        return t;
    }
} a, b;
void insert(char *s) {
    int len = strlen(s), now = root;
    for(int i = 0; i < len; ++i) {
        int c = s[i] - 'a';
        if(!t[now].ch[c]) {
            t[now].ch[c] = ++cnt;   
        }
        now = t[now].ch[c];     
    }
    t[now].p |= 1;
}
void build() {
    queue<int> q;
    for(int i = 0; i < m; ++i) {
        if(t[root].ch[i]) {
            q.push(t[root].ch[i]);
        }
    }
    while(!q.empty()) {
        int u = q.front();
        q.pop();
        for(int i = 0; i < m; ++i) {
            int v = t[u].ch[i];
            if(!v) {
                t[u].ch[i] = t[t[u].fail].ch[i];
            } else {
                t[v].fail = t[t[u].fail].ch[i];
                t[v].p |= t[t[v].fail].p;
                q.push(v);
            }
        }
    }
    ++cnt;
    a.a[cnt][cnt] = 1;
    for(int i = 0; i < cnt; ++i) {
        for(int j = 0; j < m; ++j) {
            if(t[t[i].ch[j]].p) {
                a.a[i][cnt] += 1.0 / m;
                a.a[i][0] += 1.0 / m;
            } else {
                a.a[i][t[i].ch[j]] += 1.0 / m;  
            }
        }
    }
}
int main() {
    scanf("%d%d%d", &n, &len, &m);
    for(int i = 1; i <= n; ++i) {
        scanf("%s", s);
        insert(s);
    }
    build();
    for(int i = 0; i <= cnt; ++i) b.a[i][i] = 1;
    for(; len; len >>= 1, a = a * a) {
        if(len & 1) {
            b = b * a;
        }
    }
    printf("%.7f
", (double)b.a[0][cnt]);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/19992147orz/p/8379256.html