ZOJ 3494 BCD Code (AC自动机+数位DP)

题意:给n个01字符串,求区间[x , y]中有多少个数写成BCD码后不包含以上01串。

分析:先用01字符串建立AC自动机(注意标记危险结点),然后DP。dp[i][s]表示扫描前i位后有多少个数会到达自动机的结点s.

TLE:数位DP写搓了……(为此今天专门学习了下数位DP的dfs写法,发现确实比递推的快不少)

WA:1、高精度减1写错了;  2、由于有取模,所以cal(y)可能小于cal(x-1),输出的时候没考虑到。

View Code
#include <stdio.h>
#include <string.h>
#include <queue>
using namespace std;

#define NODE 2002
#define LEN 202
#define MOD 1000000009

int node;
int next[NODE][2],fail[NODE];
bool flag[NODE];

char A[LEN],B[LEN];
int digit[LEN];

int dp[LEN][NODE];

int newnode()
{
    fail[node]=0;
    flag[node]=0;
    memset(next[node],0,sizeof(next[0]));
    return node++;
}
void insert(char *s)
{
    int i,k,cur;
    for(i=cur=0; s[i]; i++)
    {
        k=s[i]-'0';
        if(!next[cur][k])   next[cur][k]=newnode();
        cur=next[cur][k];
    }
    flag[cur]=1;
}
void makenext()
{
    int u,v;
    queue<int>q;

    q.push(0);
    while(!q.empty())
    {
        u=q.front(),q.pop();
        for(int k=0; k<2; k++)
        {
            v=next[u][k];
            if(v)   q.push(v);
            else    next[u][k]=next[fail[u]][k];
            if(u&&v)
            {
                fail[v]=next[fail[u]][k];
                flag[v] |=flag[fail[v]];
            }
        }
    }
}
int nstate(int s,int x)
{
    for(int t=8;t;t>>=1)
    {
        if(x&t) s=next[s][1];
        else    s=next[s][0];
        if(flag[s]) return -1;
    }
    return s;
}
int dfs(int pos,int s,int z,int f)
{
    if(pos==-1) return 1;
    if(z&&!f&&dp[pos][s]!=-1)   return dp[pos][s];
    int max=f?digit[pos]:9;
    int ret=0;
    for(int i=0;i<=max;i++)
    {
        int ns;
        if(z||i)    ns=nstate(s,i);
        else    ns=s;
        if(ns!=-1)
        {
            ret+=dfs(pos-1,ns,z||i,f&&i==max);
            ret%=MOD;
        }
    }
    if(z&&!f)   dp[pos][s]=ret;
    return ret;
}
int cal(char *s)
{
    int pos=0;
    int len=strlen(s+1);
    for(int i=len;i;i--)    digit[pos++]=s[i]-'0';
    return dfs(pos-1,0,0,1);
}
void sub(char *s)
{
    int i,len=strlen(s+1);
    for(i=len;i>=1&&s[i]=='0';i--)  s[i]='9';
    s[i]--;
}
int main()
{
    int t,n;
    char s[22];
    scanf("%d",&t);
    while(t--)
    {
        node=0,newnode();
        scanf("%d",&n);
        while(n--)
        {
            scanf("%s",s);
            insert(s);
        }
        makenext();
        scanf("%s%s",A+1,B+1);
        int la=strlen(A+1);
        int lb=strlen(B+1);
        int l=la>lb?la:lb;
        for(int i=0;i<l;i++)
            memset(dp[i],-1,sizeof(dp[0][0])*node);
        sub(A);
        printf("%d\n",(cal(B)-cal(A)+MOD)%MOD);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/algorithms/p/2667660.html