【POJ2778】DNA Sequence-AC自动机+矩阵优化DP

测试地址:DNA Sequence
题目大意:给定m个DNA序列,问有多少长为n的DNA序列不包含上面的任何一个DNA序列。一个DNA序列指一个仅包含A,C,G,T四种字符的字符串。
做法:本题需要用到AC自动机+矩阵优化DP。
首先不匹配上任意一个字符串,这已经是非常明显的AC自动机+DP了,只需令f(i,j)为要求的DNA序列的前i位匹配到AC自动机上的j点的方案数,直接转移即可。
然而好像有点不对劲,我们发现n非常大,O(n|s|)的复杂度一定会炸,这可怎么办呢?
我们想到我们用AC自动机构造了一个状态转移图,而图中的点数不超过|s|,也就是100,那么我们可以根据这个转移图构造出转移矩阵,这样我们就可以矩阵优化了,时间复杂度降为O(logn(|s|)3),可以通过此题。
以下是本人代码:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll mod=100000;
int m,rt=1,tot=1,ch[110][4]={0},fail[110];
int q[110],h,t,len;
ll n;
char s[110];
bool forb[110]={0};
struct matrix
{
    ll mat[110][110];
}S,M;

int f(char c)
{
    if (c=='A') return 0;
    if (c=='C') return 1;
    if (c=='G') return 2;
    if (c=='T') return 3;
}

void insert(int &v,int step)
{
    if (!v) v=++tot;
    if (step>=len) {forb[v]=1;return;}
    insert(ch[v][f(s[step])],step+1);
}

void init()
{
    scanf("%d%lld",&m,&n);
    for(int i=1;i<=m;i++)
    {
        scanf("%s",s);
        len=strlen(s);
        insert(rt,0);
    }
}

void build()
{
    h=t=q[1]=1;
    fail[1]=0;
    while(h<=t)
    {
        int v=q[h++];
        for(int i=0;i<=3;i++)
            if (ch[v][i])
            {
                int p=fail[v];
                while(p&&!ch[p][i]) p=fail[p];
                if (p) fail[ch[v][i]]=ch[p][i];
                else fail[ch[v][i]]=1;
                forb[ch[v][i]]|=forb[fail[ch[v][i]]];
                q[++t]=ch[v][i];
            }
    }
}

void work()
{
    memset(M.mat,0,sizeof(M.mat));
    for(int i=1;i<=tot;i++)
        if (!forb[i])
        {
            for(int j=0;j<=3;j++)
            {
                int p=i;
                while(p&&!ch[p][j]) p=fail[p];
                if (p&&!forb[ch[p][j]]) M.mat[ch[p][j]][i]++;
                else if (!p) M.mat[1][i]++;
            }
        }
}

matrix mult(matrix A,matrix B)
{
    matrix ans;
    memset(ans.mat,0,sizeof(ans.mat));
    for(int i=1;i<=tot;i++)
        for(int j=1;j<=tot;j++)
            for(int k=1;k<=tot;k++)
                ans.mat[i][j]=(ans.mat[i][j]+A.mat[i][k]*B.mat[k][j])%mod;
    return ans;
}

void solve()
{
    memset(S.mat,0,sizeof(S.mat));
    for(int i=1;i<=tot;i++)
        S.mat[i][i]=1;
    while(n)
    {
        if (n&1) S=mult(S,M);
        M=mult(M,M);n>>=1;
    }
    ll ans=0;
    for(int i=1;i<=tot;i++)
        ans=(ans+S.mat[i][1])%mod;
    printf("%lld",ans);
}

int main()
{
    init();
    build();
    work();
    solve();

    return 0;
}
原文地址:https://www.cnblogs.com/Maxwei-wzj/p/9793483.html