POJ 1625 Censored!(自动机DP+大数相加)

题意:给出包含n个可见字符的字符集,以下所提字符串均由该字符集中的字符构成。给出p个长度不超过10的字符串,求长为m且不包含上述p个字符串的字符串有多少个。

数据范围:1<=n,m<=50,0<=p<=10

状态设计:dp[i][j],i 步之内未经过危险结点且第 i 步到达结点 j 的路径数目。

状态转移:dp[i][j]=∑dp[i-1][k],在结点 k 加输入 s[i] 能跳到结点 j

初始化:dp[0][0]=1,对于其余的 i :dp[0][i]=0

注意:由于最后结果很大,而题中又没提到取模,所以要用到大数相加。

View Code
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <queue>
using namespace std;
#define NODE 110
int next[NODE][50];
int fail[NODE];
bool flag[NODE];
int n,L,m,node;

char ch[51];

int dp[51][NODE][100];

int cmp(const void *a,const void *b)
{
    return *(char*)a-*(char*)b;
}
void init()
{
    node=1;
    memset(next[0],0,sizeof(next[0]));
}
void add(int cur,int k)
{
    memset(next[node],0,sizeof(next[node]));
    flag[node]=0;
    next[cur][k]=node++;
}
int hash(char c)
{
    int min=0,max=n,mid;
    while(min+1!=max)
    {
        mid=min+max>>1;
        if(ch[mid]>c)    max=mid;
        else    min=mid;
    }
    return min;
}
void insert(char *s)
{
    int i,cur,k;
    for(i=cur=0;s[i];i++)
    {
        k=hash(s[i]);
        if(!next[cur][k])   add(cur,k);
        cur=next[cur][k];
    }
    flag[cur]=1;
}
void build_ac()
{
    queue<int>q;
    int cur,nxt,tmp,k;

    fail[0]=0;
    q.push(0);

    while(!q.empty())
    {
        cur=q.front(),q.pop();
        for(k=0;k<n;k++)
        {
            nxt=next[cur][k];
            if(nxt)
            {
                if(!cur)    fail[nxt]=0;
                else
                {
                    for(tmp=fail[cur];tmp&&!next[tmp][k];tmp=fail[tmp]);
                    fail[nxt]=next[tmp][k];
                }
                if(flag[fail[nxt]]) flag[nxt]=1;
                q.push(nxt);
            }
            else    next[cur][k]=next[fail[cur]][k];
        }
    }
}
void ADD(int *a,int *b)
{
    int i,c=0;
    for(i=0;i<100;i++)
    {
        a[i]+=b[i]+c;
        c=a[i]/10;
        a[i]%=10;
    }
}
void solve()
{
    memset(dp,0,sizeof(dp));
    dp[0][0][0]=1;

    for(int step=1;step<=L;step++)
    {
        for(int pre=0;pre<node;pre++)
        {
            if(flag[pre])   continue;
            for(int k=0;k<n;k++)
            {
                int cur=next[pre][k];
                if(flag[cur])   continue;
                ADD(dp[step][cur],dp[step-1][pre]);
            }
        }
    }

    int ans[100],i;
    memset(ans,0,sizeof(ans));
    for(i=0;i<node;i++) if(!flag[i])    ADD(ans,dp[L][i]);

    for(i=99;i>=0 && ans[i]==0;i--);
    if(i<0) puts("0");
    else
    {
        for(;i>=0;i--)  printf("%d",ans[i]);
        puts("");
    }
}
int main()
{
    char s[51];
    while(~scanf("%d%d%d",&n,&L,&m))
    {
        getchar();
        gets(ch);
        qsort(ch,strlen(ch),sizeof(char),cmp);

        init();
        for(int i=0;i<m;i++)
        {
            gets(s);
            insert(s);
        }
        build_ac();
        solve();
    }
    return 0;
}
原文地址:https://www.cnblogs.com/algorithms/p/2628831.html