[CSP-S模拟测试]:密码(AC自动机+DP)

题目传送门(内部题19)


输入格式

第一行两个正整数$n,k$,代表秘钥个数和要求。
接下来两个正整数$x$和$y$,意义如题所述。
接下来$n$行,每行一个正整数,意义如题所述。


输出格式

一个正整数,代表密码的种数模$1000000007(10^9+7)$的值。


样例

样例输入:

3 1
2 20
2 4 9

样例输出:

6


数据范围与提示

样例解释:

这$6$个密码为$4,9,12,14,19,20$。

数据范围:

设$s=max(x$的长度$,y$的长度$),S=sum$秘钥的长度。
对于$30\%$的数据,$0leqslant x<yleqslant 100,000,1leqslant nleqslant 10,1leqslant kleqslant 5,nleqslant Sleqslant 20$。
对于另外$20\%$的数据,$1leqslant nleqslant 100,1leqslant y−xleqslant 100,000,1leqslant kleqslant 10,nleqslant Sleqslant 200$。
对于另外$10\%$的数据,$n=k=1,1leqslant sleqslant 500$,秘钥为$666$。
对于另外$10\%$的数据,$n=k=1,1leqslant sleqslant 500$,秘钥为$233$。
对于$100\%$的数据,$1leqslant sleqslant 500,1leqslant nleqslant 100,nleqslant Sleqslant 200,1leqslant kleqslant 10$。
$Warning$:可能存在相同的秘钥,应该算作多次。


题解

看出来了$AC$自动机,看出来了数位$DP$,然后你也看到了祖宗……

考虑如何在$AC$自动机上跑数位$DP$,定义$dp[i][j][k][0/1]$表示到了第$i$个数字,在$AC$自动机上的第$j$个节点,已经匹配了$k$个密钥,是否有限制。

时间复杂度:$Theta(40 imes s imes S imes k)$。

期望得分:$100$分。

实际得分:$100$分。


代码时刻

#include<bits/stdc++.h>
using namespace std;
long long n,K;
long long x[600],y[600],a[600],m[600];
char xx[600],yy[600],aa[600];
long long s[600];
long long trie[600][20],ed[600],nxt[600],que[600],cnt;
long long dp[600][500][20][2];
void insert(long long *a)
{
	int p=0;
	for(int i=1;i<=a[0];i++)
	{
		if(!trie[p][a[i]])trie[p][a[i]]=++cnt;
		p=trie[p][a[i]];
	}
	ed[p]++;
}
void build()
{
	int head=1,tail=1;
	while(head<=tail)
	{
		for(int i=0;i<10;i++)
		{
			if(!trie[que[head]][i])continue;
			int flag=nxt[que[head]];
			while(!trie[flag][i]&&flag)flag=nxt[flag];
			que[++tail]=trie[que[head]][i];
			if(trie[flag][i]!=trie[que[head]][i])nxt[que[tail]]=trie[flag][i];
			ed[que[tail]]+=ed[nxt[que[tail]]];
		}
		head++;
	}
}
long long getans1()
{
	memset(dp,0,sizeof(dp));
	s[x[0]+1]=1;
	for(int i=x[0];i;i--)s[i]=(s[i+1]+m[x[0]-i]*x[i]%1000000007)%1000000007;
	dp[0][0][0][1]=1;
	int flag1,flag2=0;
	for(int i=0;i<=x[0];i++)
	{
		for(int j=0;j<=cnt;j++)
		{
			for(int k=0;k<K;k++)
			{
				if(i==x[0])break;
				if(dp[i][j][k][0])
				{
					for(int l=0;l<=9;l++)
					{
						flag1=j;
						while(!trie[flag1][l]&&flag1)flag1=nxt[flag1];
						if(trie[flag1][l])flag1=trie[flag1][l];
						dp[i+1][flag1][min(K,k+ed[flag1])][0]=(dp[i+1][flag1][min(K,k+ed[flag1])][0]+dp[i][j][k][0])%1000000007;
					}
				}
				if(dp[i][j][k][1])
				{
					for(int l=0;l<x[i+1];l++)
					{
						flag1=j;
						while(!trie[flag1][l]&&flag1)flag1=nxt[flag1];
						if(trie[flag1][l])flag1=trie[flag1][l];
						dp[i+1][flag1][min(K,k+ed[flag1])][0]=(dp[i+1][flag1][min(K,k+ed[flag1])][0]+dp[i][j][k][1])%1000000007;
					}
					flag1=j;
					while(!trie[flag1][x[i+1]]&&flag1)flag1=nxt[flag1];
					if(trie[flag1][x[i+1]])flag1=trie[flag1][x[i+1]];
					dp[i+1][flag1][min(K,k+ed[flag1])][1]=(dp[i+1][flag1][min(K,k+ed[flag1])][1]+dp[i][j][k][1])%1000000007;
				}
			}
			flag2=(flag2+dp[i][j][K][0]*m[x[0]-i]%1000000007+dp[i][j][K][1]*s[i+1]%1000000007)%1000000007;
		}
	}
	return flag2;
}
long long getans2()
{
	memset(dp,0,sizeof(dp));
	s[y[0]+1]=1;
	for(int i=y[0];i;i--)s[i]=(s[i+1]+m[y[0]-i]*y[i]%1000000007)%1000000007;
	dp[0][0][0][1]=1;
	int flag1,flag2=0;
	for(int i=0;i<=y[0];i++)
	{
		for(int j=0;j<=cnt;j++)
		{
			for(int k=0;k<K;k++)
			{
				if(i==y[0])break;
				if(dp[i][j][k][0])
				{
					for(int l=0;l<=9;l++)
					{
						flag1=j;
						while(!trie[flag1][l]&&flag1)flag1=nxt[flag1];
						if(trie[flag1][l])flag1=trie[flag1][l];
						dp[i+1][flag1][min(K,k+ed[flag1])][0]=(dp[i+1][flag1][min(K,k+ed[flag1])][0]+dp[i][j][k][0])%1000000007;
					}
				}
				if(dp[i][j][k][1])
				{
					for(int l=0;l<y[i+1];l++)
					{
						flag1=j;
						while(!trie[flag1][l]&&flag1)flag1=nxt[flag1];
						if(trie[flag1][l])flag1=trie[flag1][l];
						dp[i+1][flag1][min(K,k+ed[flag1])][0]=(dp[i+1][flag1][min(K,k+ed[flag1])][0]+dp[i][j][k][1])%1000000007;
					}
					flag1=j;
					while(!trie[flag1][y[i+1]]&&flag1)flag1=nxt[flag1];
					if(trie[flag1][y[i+1]])flag1=trie[flag1][y[i+1]];
					dp[i+1][flag1][min(K,k+ed[flag1])][1]=(dp[i+1][flag1][min(K,k+ed[flag1])][1]+dp[i][j][k][1])%1000000007;
				}
			}
			flag2=(flag2+dp[i][j][K][0]*m[y[0]-i]%1000000007+dp[i][j][K][1]*s[i+1]%1000000007)%1000000007;
		}
	}
	return flag2;
}
int main()
{
	scanf("%lld%lld%s%s",&n,&K,xx+1,yy+1);
	x[0]=strlen(xx+1);
	y[0]=strlen(yy+1);
	m[0]=1;
	for(int i=1;i<=500;i++)m[i]=m[i-1]*10%1000000007;
	for(int i=1;i<=x[0];i++)x[i]=xx[i]-'0';
	for(int i=1;i<=y[0];i++)y[i]=yy[i]-'0';
	for(int i=1;i<=n;i++)
	{
		scanf("%s",aa+1);
		a[0]=strlen(aa+1);
		for(int j=1;j<=a[0];j++)a[j]=aa[j]-'0';
		insert(a);
	}
	build();
	printf("%lld",(getans2()-getans1()+1000000007)%1000000007);
	return 0;
}

rp++

原文地址:https://www.cnblogs.com/wzc521/p/11447192.html