bzoj1009

AC自动机+dp+矩阵乘法

我们先对串建立AC自动机,然后进行dp+矩阵乘法。AC自动机加上trie图优化,root的每个儿子如果没有就都填上,然后建立矩阵,mat[child[u][i]][u]=1,如果在trie图上child[u][i]是u的儿子,并且u和儿子都不是危险节点,然后初始值是dp[1][child[root][i]]=1,然后系数矩阵自乘n-1次,再乘上dp矩阵就行了,答案是dp[n][1->cnt]。这样避免了通过危险串,并且因为root的每个儿子都能走,并且因为trie图优化,所以每种情况都能走到。

#include<bits/stdc++.h>
using namespace std;
const int N = 110;
int n, m, mod, root, cnt, answer;
int child[N][11], fail[N], danger[N];
char s[N];
struct mat {
    int a[N][N];
    mat friend operator * (mat A, mat B)
    {
        mat ret;
        memset(ret.a, 0, sizeof(ret.a));
        for(int k = 1; k <= cnt; ++k)
            for(int i = 1; i <= cnt; ++i)
                for(int j = 1; j <= cnt; ++j)
                    ret.a[i][j] = (ret.a[i][j] + A.a[i][k] * B.a[k][j]) % mod;
        return ret;                 
    }
} A, ans;
void insert(char s[])
{
    int now = root;
    for(int i = 0; i < m; ++i)
    {
        if(child[now][s[i] - '0'] == 0) child[now][s[i] - '0'] = ++cnt;
        if(i != m - 1) A.a[child[now][s[i] - '0']][now] = 1;
        now = child[now][s[i] - '0'];
    }
    danger[now] = 1;
}
void construct_fail()
{
    queue<int> q;
    for(int i = 0; i <= 9; ++i) if(child[root][i]) q.push(child[root][i]);
    while(!q.empty())
    {
        int u = q.front();
        q.pop();
        if(danger[u]) continue;
        for(int i = 0; i <= 9; ++i)
        {
            int &v = child[u][i];
            if(v == 0) 
            {
                v = child[fail[u]][i];
                if(danger[v] == 0) A.a[v][u] = 1;
            }
            else
            {
                fail[v] = child[fail[u]][i];
                danger[v] |= danger[fail[v]];
                q.push(v);
            } 
        }
    }
}
int main()
{
    scanf("%d%d%d%s", &n, &m, &mod, s);
    insert(s);
    for(int i = 0; i <= 9; ++i) if(child[root][i] == 0) child[root][i] = ++cnt;
    construct_fail();
    for(int i = 0; i <= 9; ++i) ans.a[child[root][i]][1] = 1;
    for(int t = n - 1; t; t >>= 1, A = A * A) if(t & 1) ans = A * ans;
    for(int i = 1; i <= cnt; ++i) answer = (answer + ans.a[i][1]) % mod;
    printf("%d
", answer);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/19992147orz/p/7363437.html