Aho-Corasick

Wireless Password

题意:

给m(m<=10)个模板串,问长度为n(<=25)的字符串至少包含k种模板串的总的种类数。

0.首先出现了多个考虑Aho-Corasick。

1.本题模板串的个数是小于10的,所以可以将这些模板串状态压缩,在建立fail指针的时候,将这颗Tirie树联通好,那么就只需要进行状态转移。

2.状态定义dp[len][i][s0],表示目前字符串的长度为n位于i号节点状态为s0的总方案数

3.状态转移:dp[len+1][j][s0|s1] += dp[len][i][s0];

4.最后只需要查出长度为n,状态数多于等于k的dp值并求和。

#include <cstdio>
#include <iostream>
#include <queue>
#include <cstring>
#include <algorithm>
#define LL long long
using namespace std;

//Globel define
const int N = 30;
const int alp = 26;
const int mod = 20090717;
char buf[N];
int n,m,k;
int ncnt;
int dp[N][110][1035];
struct node{
	int i,S;
	node *ch[alp],*fail;
	void init(){
		S = 0;
		for(int i = 0;i < alp;++i)ch[i] = NULL;
	}
}trie[110];
//end Globel define


node *newnode(){
	node *p = &trie[ncnt];
	p->init();
	p->i = ncnt++;
	return p;
}

void insert(node *root,char *s,int i){
	node *p = root;
	int S = 0;
	while(*s != ''){
		if(!p->ch[*s-'a'])p->ch[*s-'a'] = newnode();
		p = p->ch[*s-'a'];
		++s;
	}
	p->S |= (1<<i);
}

void buildfail(node *root){
	queue <node *> q;
	root->fail = NULL;
	q.push(root);
	while(!q.empty()){
		node *p = q.front();q.pop();
		for(int i = 0;i < alp;++i){
			if(p->ch[i]){
				node *next = p->fail;
				while(next && !next->ch[i])next = next->fail;
				p->ch[i]->fail = next ? next->ch[i] : root;
				p->ch[i]->S |= p->ch[i]->fail->S;
				q.push(p->ch[i]);
			}
			else p->ch[i] = (p==root) ? root:p->fail->ch[i];
		}
	}
}

int count(int S){
	int cnt = 0;
	for(int i = 0;i < 10;++i)if(S&(1<<i))cnt++;
	return cnt;
}

int main(){
	while(scanf("%d%d%d",&n,&m,&k) && n+m+k){
		ncnt = 0;
		memset(dp,0,sizeof(dp));
		memset(trie,0,sizeof(trie));
		node *root = newnode();
		for(int i = 0;i < m;++i){
			scanf("%s",buf);
			insert(root,buf,i);
		}
		buildfail(root);
		dp[0][0][0] = 1;
		for(int l = 0;l < n;++l)
		for(int i = 0;i < ncnt;++i)
		for(int s = 0;s < (1<<m);++s){
			if(!dp[l][i][s])continue;
			for(int c = 0;c < alp;++c){
				node *next = trie[i].ch[c];
				if(!next)continue;
				int &ret = dp[l+1][next->i][s|next->S];
				ret = (ret+dp[l][i][s])%mod;
			}
		}
		int ans = 0;
		for(int s = 0;s < (1<<m);++s)if(count(s)>=k){
			for(int i = 0;i < ncnt;++i)ans = (ans+dp[n][i][s])%mod;
		}
		cout<<ans<<endl;
	}
	return 0;
}

  

原文地址:https://www.cnblogs.com/xgtao984/p/5706150.html