题解 CF585F 【Digits of Number Pi】

考虑用数位 (DP) 来统计数字串个数,用 (SAM) 来实现子串的匹配。

设状态 (f(pos,cur,lenth,lim,flag)),表示数位的位数,在 (SAM) 上的节点,匹配的长度,是否有最高位限制,是否已经满足要求。

(dfs) 转移时,若当前节点能接着匹配枚举到的字符,就直接转移,若不能,则在 (Parent) 树上向上跳,直到能接着匹配,转移过程中判断是否满足条件即可。

(code:)

#include<bits/stdc++.h>
#define maxn 2010
#define maxd 55
#define mod 1000000007
using namespace std;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int n,d,tot=1,las=1,root=1,cnt;
int fa[maxn],len[maxn],ch[maxn][12],num[maxn],f[maxd][maxn][maxd][2][2];
char s[maxn],a[maxn],b[maxn];
void insert(int c)
{
    int p=las,np=las=++tot;
    len[np]=len[p]+1;
    while(p&&!ch[p][c]) ch[p][c]=np,p=fa[p];
    if(!p) fa[np]=root;
    else
    {
        int q=ch[p][c];
	if(len[q]==len[p]+1) fa[np]=q;
	else
	{
	    int nq=++tot;
	    memcpy(ch[nq],ch[q],sizeof(ch[q]));
	    len[nq]=len[p]+1,fa[nq]=fa[q],fa[np]=fa[q]=nq;
	    while(ch[p][c]==q) ch[p][c]=nq,p=fa[p];
	}
    }
}
int dp(int pos,int cur,int lenth,bool lim,bool flag)
{
    if(pos==cnt+1) return flag;
    if(f[pos][cur][lenth][lim][flag]!=-1) return f[pos][cur][lenth][lim][flag];
    int v=0,ma=9;
    if(lim) ma=num[pos];
    for(int c=0;c<=ma;++c)
    {
        if(flag) v=(v+dp(pos+1,cur,lenth,lim&&c==ma,1))%mod;
        else
        {
            int p=cur;
            if(ch[p][c]) v=(v+dp(pos+1,ch[p][c],lenth+1,lim&&c==ma,lenth+1>=d/2))%mod;
            else
            {
                while(p&&!ch[p][c]) p=fa[p];
		if(p) v=(v+dp(pos+1,ch[p][c],len[p]+1,lim&&c==ma,len[p]+1>=d/2))%mod;
		else v=(v+dp(pos+1,root,0,lim&&c==ma,0))%mod;
            }
        }
    }
    return f[pos][cur][lenth][lim][flag]=v;
}
int solve(char *s)
{
    cnt=strlen(s+1),memset(f,-1,sizeof(f));
    for(int i=1;i<=cnt;++i) num[i]=s[i]-'0';
    return dp(1,root,0,1,0);
}
int main()
{
    scanf("%s%s%s",s+1,a+1,b+1),n=strlen(s+1),d=cnt=strlen(a+1),a[cnt--]--;
    for(int i=1;i<=n;++i) insert(s[i]-'0');
    while(a[cnt]-'0'==-1) a[cnt]--,a[cnt--]=9+'0';
    printf("%d",(solve(b)-solve(a)+mod)%mod);
    return 0;
}
原文地址:https://www.cnblogs.com/lhm-/p/13303511.html