POJ2778 DNA Sequence [AC自动机+矩阵]

   又是一道调了大半天的题,最后发现竟然是自己建立trie图的地方有个小BUG,这个小BUG在字符串匹配时没什么影响,所以一直没发现出来。刚刚学习,还是理解的不够深入啊。现在这个trie图应该算是写的很简洁了,可以拿来当模版了。

  题意很简单,就是问长度为n不包含若干子串的串一共有多少个。这里可以用AC自动机DP,首先对于单词的结尾节点,标记为非法节点,一旦走到了一个非法节点,就说明包含了某个单词。网上的解题报告很多人说是AC自动机DP。但这题我的做法似乎没有用到DP,只是用矩阵加速了一下罢了。。。。首先标记出非法节点,补全trie图,用一个矩阵表示从每个合法节点到其它合法节点转移的方案,可以表示为一个邻接矩阵M,求第N步从根节点到其它节点有多少方案,这个问题就是很简单的路径方案数问题,只要求M^N,然后求sum(M[1][1]...[1][N])就行了。。。。

  需要注意的是,在建立trie图的过程中,如果一个节点的失败指针指向了另一个非法节点,则说明这个节点也是个非法节点。比如对于字符串{AG,CAGT},第一个字符串中的G和第二个字符串中的T很显然是非法节点。但是在第二个字符串中走到G时,实际上已经包含了AG这个字符串,它的失败指针指向AG中的G,所以这个G也是一个非法节点。

  调了大半天终于AC了,稍加优化的程序跑了32ms,居然刷进了第一版。。0ms排在第一的竟然是ZFY学长,YM。。。

#include <string.h>
#include <stdio.h>
#include <queue>
#define MAXN 110
#define MOD 100000
typedef long long LL;
LL dmat[MAXN][MAXN];
struct matrix{
    LL mz[MAXN][MAXN];int n;
    #define FOR(i) for(int i=1;i<=n;i++)
    //初始化矩阵,空矩阵,单位矩阵和dmat矩阵
    matrix(int nn,int type):n(nn){
        if(type==0)FOR(i)FOR(j)mz[i][j]=0;
        else if(type==1)FOR(i)FOR(j)mz[i][j]=(i==j)?1:0;
        else FOR(i)FOR(j)mz[i][j]=dmat[i][j];
    }
    //重载矩阵乘法,10^5*10^5*100不会超longlong的,最后一次性模就可以了,模是很费时的
    matrix operator *(const matrix& b)const{
        matrix ans(n,0);
        FOR(i)FOR(j)if(mz[i][j])
            FOR(k)ans.mz[i][k]+=mz[i][j]*b.mz[j][k];
        FOR(i)FOR(j)if(ans.mz[i][j]>MOD)ans.mz[i][j]%=MOD;
        return ans;
    }
    //二分矩阵乘法
    matrix binMat(int x){
        matrix ans(n,1),tmp(n,2);
        for(;x;tmp=tmp*tmp,x>>=1){
            if(x&1)ans=ans*tmp;
        }
        return ans;
    }
};
int n,m;
char s[12];
int next[MAXN][4],fail[MAXN],flag[MAXN],id[MAXN],ids,pos;
int trans(char c){
    if(c=='A')return 0;
    if(c=='C')return 1;
    if(c=='T')return 2;
    return 3;
}
int newnode(){
    for(int i=0;i<4;i++)next[pos][i]=0;
    fail[pos]=flag[pos]=id[pos]=0;
    return pos++;
}
void insert(char *s){
    int p=0,len=strlen(s);
    for(int i=0;i<len;i++){
        int &x=next[p][trans(s[i])];
        p=x?x:x=newnode();
    }
    flag[p]=1;
}
int q[MAXN],front,rear;
void makenext(){
    q[front=rear=0]=0,rear++;
    while(front<rear){
        int u=q[front++];
        for(int i=0;i<4;i++){
            int v=next[u][i];
            if(flag[v])continue;
            if(v==0)next[u][i]=next[fail[u]][i];
            else q[rear++]=v;
            //这个地方忘了判断v是否是0了,调了很久...省代码还是要小心啊..
            if(v&&u){
                fail[v]=next[fail[u]][i];
                //如果指向一个非法节点,那这个节点也是一个非法节点(比如cg和acgt这样的串,第二个串中的g也是非法的)
                if(flag[fail[v]])flag[v]=1;
            }
        }
    }
}
int main(){
    while(scanf("%d%d",&m,&n)!=EOF){
        pos=ids=0;newnode();
        memset(dmat,0,sizeof dmat);
        for(int i=0;i<m;i++){
            scanf("%s",s);
            insert(s);
        }
        makenext();
        //建立矩阵,从每个合法节点到另一个节点转移的方案数,类似于邻接矩阵
        for(int u=0;u<pos;u++){
            if(flag[u])continue;
            for(int i=0;i<4;i++){
                int v=next[u][i];
                if(flag[v])continue;
                if(id[u+1]==0)id[u+1]=++ids;
                if(id[v+1]==0)id[v+1]=++ids;
                dmat[id[u+1]][id[v+1]]++;
            }
        }
        matrix mt=matrix(ids,2).binMat(n);

        LL ans=0;
        for(int i=1;i<=mt.n;i++)
            ans+=mt.mz[1][i];
        ans%=MOD;
        printf("%lld\n",ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/swm8023/p/2625297.html