HDU2825 Wireless Password [AC自动机+压缩DP]

  给出M个单词,问长度为N的包含不少于K个单词的字符串一共有多少个。到目前为止做的自动机的题目好像都是差不多样子。。都是包含不包含单词之类的。。

  这题用d[i][j][k]表示第i步走到字符j包含了单词集合k,因为一共只有10个单词,可以用二进制压缩状态表示这个集合。注意在建立trie图时要合并节点和它的fail节点的状态,一开始没想到这个WA了一次。状态转移方程为

  

  不加优化的代码交上去时间接近TLE。。。于是改成了滚动数组,并用for循环清0,还是要跑360ms左右。这相对于1s的时限还是很慢的。。。于是点了下Statistic,发现大家都是跑了好几百ms。。。。

#include <stdio.h>
#include <string.h>
#define MAXN 105
#define ALP 26
#define MOD 20090717

int n,m,k;
char s[11];
int next[MAXN][ALP],fail[MAXN],flag[MAXN],pos;
int newnode(){
    for(int i=0;i<ALP;i++)next[pos][i]=0;
    fail[pos]=flag[pos]=0;
    return pos++;
}
void insert(char *s,int id){
    int p=0;
    for(int i=0;s[i];i++){
        int k=s[i]-'a',&x=next[p][k];
        p=x?x:x=newnode();
    }
    flag[p]=1<<(id-1);
}
int q[MAXN],front,rear;
void makenext(){
    q[front=rear=0]=0,rear++;
    while(front<rear){
        int u=q[front++];
        for(int i=0;i<ALP;i++){
            int v=next[u][i];
            if(v==0)next[u][i]=next[fail[u]][i];
            else q[rear++]=v;
            if(u&&v){
                fail[v]=next[fail[u]][i];
                //这里注意合并状态
                flag[v]|=flag[fail[v]];
            }
        }
    }
}
int d[2][MAXN][MAXN*10];
//估计是数据比较多,直接用memset比较慢
void cleard(int x,int f){
    #define FOR(i,n) for(int i=0;i<n;i++)
    FOR(i,pos)FOR(j,f)d[x][i][j]=0;
}
int dp(){
    int full=(1<<m),cur=0;
    cleard(0,full);
    d[0][0][0]=1;
    for(int i=0;i<n;i++){
        cur^=1;
        cleard(cur,full);
        for(int u=0;u<pos;u++){
            for(int s=0;s<full;s++){
                //如果这个状态还没有到达过
                if(d[cur^1][u][s]==0)continue;
                //用这个状态跟新它的后继节点
                for(int k=0;k<ALP;k++){
                    int v=next[u][k],&x=d[cur][v][flag[v]|s];
                    x+=d[cur^1][u][s];
                    if(x>=MOD)x%=MOD;
                }
            }
        }
    }
    //统计第n步走过>=k个单词的状态的总数
    int ans=0;
    for(int s=0;s<full;s++){
        int bit=0;
        for(int i=s;i;bit++,i-=(i&-i));
        if(bit<k)continue;
        for(int i=0;i<pos;i++){
            ans+=d[cur][i][s];
            if(ans>=MOD)ans%=MOD;
        }
    }
    return ans;
}
int main(){
    //freopen("test.in","r",stdin);
    while(scanf("%d%d%d",&n,&m,&k),n||m||k){
        pos=0;newnode();
        for(int i=1;i<=m;i++){
            scanf("%s",s);
            insert(s,i);
        }
        makenext();
        int ans=dp();
        printf("%d\n",ans);
    }
    return 0;
}

  

原文地址:https://www.cnblogs.com/swm8023/p/2626535.html