POJ1625 Censored! [AC自动机+DP]

  应该算是最基础的AC自动机DP了吧。。

  跟前面做的两道用矩阵加速的AC自动机题目意思差不多,都是求不包含给定单词的单词数。区别就是给定单词较多,需要的字符串比较短,然后结果不取模,要用到高精度。

  trie图中大约有50*10个节点,如果建立矩阵用矩阵加速无论是时间复杂度还是空间复杂度都是会超的(时间大约是500^3*log(50),空间是500^2*高精度数组)。所以这里要用到DP了,用dp[i][j]表示第i步在第j个节点的方法数。flag[k]=0表示不是非法节点,son是k的儿子节点的集合。

  

  一开始爆空间了,POJ真心抠啊,这题就给了10M内存,后来高精度里面用char存,然后再用上滚动数组,瞬间就200K内存了=。=

  不过效率还是好低,要跑400ms,不知道是我的高精度太慢了还是DP写的太搓了,感觉我这DP写的很有bellman_ford的感觉啊,应该是还能优化的吧。。

#include <string.h>
#include <stdio.h>
#include <queue>
#include <algorithm>
#include <map>
#define MAXN 505
struct bign{
    char s[100];int len;
    bign operator =(const char *ss){
        len=strlen(ss);
        for(int i=0;i<len;i++)s[i]=ss[len-i-1]-'0';
        return *this;
    }
    void print(){
        for(int i=len-1;i>=0;i--)printf("%d",s[i]);printf("\n");
    }
    bign operator +(const bign& b)const {
        bign c;
        c.len=0;
        for(int i=0,g=0;g||i<len||i<b.len;i++){
            if(i<len)g+=s[i];
            if(i<b.len)g+=b.s[i];
            c.s[c.len++]=g%10,g/=10;
        }
        return c;
    }
}d[2][MAXN],ans;
char s[100];
int n,m,ps,dy[500];
int next[MAXN][55],fail[MAXN],flag[MAXN],pos;
int newnode(){
    for(int i=0;i<m;i++)next[pos][i]=0;
    fail[pos]=flag[pos]=0;
    return pos++;
}
void insert(char *s){
    int p=0;
    for(int i=0;s[i];i++){
        int k=dy[s[i]+128],&x=next[p][k];
        p=x?x:x=newnode();
    }
    flag[p]=1;
}
void makenext(){
    std::queue<int> q;
    q.push(0);
    while(!q.empty()){
        int u=q.front();q.pop();
        for(int i=0;i<n;i++){
            int v=next[u][i];
            if(v==0)next[u][i]=next[fail[u]][i];
            else q.push(v);
            if(u&&v){
                fail[v]=next[fail[u]][i];
                if(flag[fail[v]])flag[v]=1;
            }
        }
    }
}
void dp(){
    for(int j=0;j<pos;j++)d[0][j]="0";
    d[0][0]="1",ans="0";
    int cur=0;
    for(int i=0;i<m;i++){
        cur^=1;
        for(int j=0;j<pos;j++)d[cur][j]="0";
        for(int u=0;u<pos;u++){
            if(flag[u])continue;
            for(int k=0;k<n;k++){
                int v=next[u][k];
                if(flag[v])continue;
                d[cur][v]=d[cur][v]+d[cur^1][u];
            }
        }
    }
    for(int i=0;i<pos;i++)ans=ans+d[cur][i];
}
int main(){
    //freopen("test.in","r",stdin);
    while(scanf("%d%d%d",&n,&m,&ps)!=EOF){
        gets(s);
        gets(s);
        pos=0;newnode();
        for(int i=0;i<n;i++)dy[s[i]+128]=i;
        for(int i=0;i<ps;i++){
            gets(s);
            insert(s);
        }
        makenext();
        dp();
        ans.print();
    }
    return 0;
}

  

  

原文地址:https://www.cnblogs.com/swm8023/p/2626242.html