比较恶心的一道题,主要在数位DP那里调试了好久。
题目大意:0~9可以用对应的BCD码表示,给出A~B之间的十进制数,将他们化成BCD码的形式。但有些01串是不能出现的,求能化成BCD码的数的个数。
先建图,然后将二进制转十进制,比如next[p][7]=next[p][0]->[1]->[1]->[1],如果后面这条路能走出来next[p][7]就存在,否则为-1。之后数位DP,求能走出的比X的小的数有多少种,这里要注意细节,很容易写错。求出后DP(B)-DP(A)就行了。
1 #include <stdio.h> 2 #include <string.h> 3 #include <algorithm> 4 #define MAXN 2001 5 #define MOD 1000000009 6 typedef long long LL; 7 int cas,n; 8 char s[205]; 9 int next[MAXN][2],fail[MAXN],flag[MAXN],pos; 10 int nextd[MAXN][10]; 11 int newnode(){ 12 next[pos][0]=next[pos][1]=0; 13 fail[pos]=flag[pos]=0; 14 return pos++; 15 } 16 void insert(char *s){ 17 int p=0; 18 for(int i=0;s[i];i++){ 19 int &x=next[p][s[i]-'0']; 20 p=x?x:x=newnode(); 21 } 22 flag[p]=1; 23 } 24 int q[MAXN],front,rear; 25 void makenext(){ 26 q[front=rear=0]=0,rear++; 27 while(front<rear){ 28 int u=q[front++]; 29 for(int i=0;i<2;i++){ 30 int v=next[u][i]; 31 if(v==0)next[u][i]=next[fail[u]][i]; 32 else q[rear++]=v; 33 if(u&&v)flag[v]|=flag[fail[v]=next[fail[u]][i]]; 34 } 35 } 36 } 37 38 //d[i][j]表示第i个点之后长度为j的单词有多少种 39 //注意只有在非第一位并且无限制时才能记忆化搜索 40 LL d[MAXN][201]; 41 //limit表示是否有限制,first表示该位的下限,第一位为1,其它为0 42 LL dp(int u,int l,int limit,int first){ 43 if(first==0&&limit==0&&d[u][l]!=-1)return d[u][l]; 44 if(l==0)return 1; 45 LL tmp=0; 46 //该位的上限 47 int last=limit?s[l-1]-'0':9; 48 for(int i=first;i<=last;i++){ 49 int v=nextd[u][i]; 50 if(v==-1)continue; 51 tmp+=dp(v,l-1,limit&&last==i,0); 52 if(tmp>MOD)tmp-=MOD; 53 } 54 if(limit==0)d[u][l]=tmp; 55 return tmp; 56 } 57 LL solve(char *s){ 58 //更改字符串顺序便于DP并去掉头部的0 59 int len=strlen(s); 60 for(int i=0;i<len/2;i++)std::swap(s[i],s[len-i-1]); 61 while(len>1&&s[len-1]=='0')len--; 62 LL ans=0; 63 //枚举长度从1~len的串,只有当长度为Len时是有限制的 64 for(int i=len;i>=1;i--){ 65 ans+=dp(0,i,i==len,1); 66 if(ans>MOD)ans-=MOD; 67 } 68 return ans; 69 } 70 //将二进制转换成10进制,nextd表示每个点后可走的10进制数 71 int getnext(int p,int x){ 72 if(flag[p])return -1; 73 for(int i=3;i>=0;i--){ 74 p=next[p][((x>>i)&1)?1:0]; 75 if(flag[p])return -1; 76 } 77 return p; 78 } 79 void makenextd(){ 80 for(int u=0;u<pos;u++) 81 for(int i=0;i<10;i++) 82 nextd[u][i]=getnext(u,i); 83 } 84 int main(){ 85 //freopen("test.in","r",stdin); 86 scanf("%d",&cas); 87 while(cas--){ 88 scanf("%d",&n); 89 pos=0;newnode(); 90 while(n--){ 91 scanf("%s",s); 92 insert(s); 93 } 94 makenext(); 95 makenextd(); 96 scanf("%s",s); 97 98 //找比当前字符串小1的数 99 int len=strlen(s); 100 for(int i=len-1;i>=0;i--){ 101 if(s[i]!='0'){s[i]--;break;} 102 else s[i]='9'; 103 } 104 memset(d,-1,sizeof d); 105 LL ans=-solve(s); 106 scanf("%s",s); 107 ans+=solve(s); 108 printf("%lld\n",(ans%MOD+MOD)%MOD); 109 } 110 return 0; 111 }