BZOJ 3530 数数

数位dp+AC自动机。dp[i][j]表示考虑到第i位,当前在节点j的方案数。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define maxn 2050
#define maxv 2050
#define maxe 4050
#define mod 1000000007
using namespace std;
int m,son[maxn][10],fail[maxn],root=0,tot=0,g[maxv],nume=1,dp[maxn][maxn],l,bit[maxn];
char n[maxn],s[maxn];
bool flag[maxn],vis[maxn];
queue <int> q;
struct edge
{
    int v,nxt;
}e[maxe];
void addedge(int u,int v)
{
    e[++nume].v=v;e[nume].nxt=g[u];
    g[u]=nume;
}
void get_bit()
{
    l=strlen(n);
    for (int i=0;i<l;i++) bit[i+1]=n[i]-'0';
}
void insert()
{
    int l=strlen(s),now=root;
    for (int i=0;i<l;i++)
    {
        int nb=s[i]-'0';
        if (!son[now][nb]) son[now][nb]=++tot;
        now=son[now][nb];
    }
    flag[now]=true;
}
void build_AC()
{
    q.push(root);
    while (!q.empty())
    {
        int head=q.front();q.pop();
        for (int i=0;i<=9;i++)
        {
            if (son[head][i])
            {
                if (head!=root) fail[son[head][i]]=son[fail[head]][i];
                else fail[son[head][i]]=root;
                addedge(fail[son[head][i]],son[head][i]);
                q.push(son[head][i]);
            }
            else son[head][i]=son[fail[head]][i];
        }
    }
}
void get_flag()
{
    q.push(root);vis[root]=true;
    while (!q.empty())
    {
        int head=q.front();q.pop();
        for (int i=g[head];i;i=e[i].nxt)
        {
            int v=e[i].v;
            if (vis[v]) continue;
            vis[v]=true;flag[v]|=flag[head];
            q.push(v);
        }
    }
}
int dfs(int now,int pre,bool up)
{
    if (now==l+1) return !flag[pre];
    if ((dp[now][pre]) && (!up)) return dp[now][pre];
    int lim=up?bit[now]:9,ret=0;
    for (int i=0;i<=lim;i++)
    {
        int pos=pre;
        while (pos!=root && !son[pos][i]) pos=fail[pos];
        pos=son[pos][i];
        if (flag[pos]) continue;
        ret=(ret+dfs(now+1,pos,up&&(i==lim)))%mod;
    }
    if (!up) dp[now][pre]=ret;
    return ret;
}
void get_dp()
{
    int ans=0;
    for (int i=1;i<=l-1;i++)
        for (int j=1;j<=9;j++)
        {
            if (flag[son[root][j]]) continue;
            ans=(ans+dfs(l-i+2,son[root][j],0))%mod;
        }
    for (int i=1;i<=bit[1]-1;i++)
    {
        if (flag[son[root][i]]) continue;
        ans=(ans+dfs(2,son[root][i],0))%mod;
    }
    if (!flag[son[root][bit[1]]])
        ans=(ans+dfs(2,son[root][bit[1]],1))%mod;
    printf("%d
",ans);
}
int main()
{
    scanf("%s",n);get_bit();
    scanf("%d",&m);
    for (int i=1;i<=m;i++)
    {
        scanf("%s",s);
        insert();
    }
    build_AC();
    get_flag();
    get_dp();
    return 0;
}
原文地址:https://www.cnblogs.com/ziliuziliu/p/6415859.html