【hdu2825】ac自动机 + 状压dp

传送门

题目大意:

给你一些密码片段字符串,让你求长度为n,且至少包含k个不同密码片段串的字符串的数量。

题解:

因为密码串不多,可以考虑状态压缩

设dp[i][j][sta]表示长为i的字符串匹配到j节点且状态为sta的数量。

其中sta存储的是包含的密码串情况,在构建fail指针时,当前节点要并上fail指针所指的节点。

跑ac自动机,儿子节点从父亲节点转移。

最后取dp[len][...][sta]的和,其中sta满足二进制中1的数量>=k,

这一点可以像树状数组的lowbit那样快速求出:

inline int count(int x){
    int ret = 0;
    while(x){
        ret++;
        x -= (x & -x);
    }
    return ret;
}

code

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<algorithm>
#include<queue>
using namespace std;
const int N = 20, L = 20, Mod = 20090717;
int n, m, k, tot;
long long dp[30][110][1100], ans;
char s[20];
queue<int> que;
struct node{
    int trans[27];
    int fail, no;
    int state;
    inline void clear(){
        memset(trans, 0, sizeof trans);
        fail = state = no = 0;
    }
}trie[1010];
inline int getVal(char st){
    return st - 'a' + 1;
}
inline void insert(int num){
    int len = strlen(s + 1), pos = 1;
    for(int i = 1; i <= len; i++){
        int val = getVal(s[i]);
        if(!trie[pos].trans[val])
            trie[trie[pos].trans[val] = ++tot].clear();
        pos = trie[pos].trans[val];
    }
    trie[pos].state |= 1 << num;
}
inline void buildFail(){
    for(int i = 1; i <= 26; i++) trie[0].trans[i] = 1;
    que.push(1);
    while(!que.empty()){
        int u = que.front(); que.pop();
        for(int i = 1; i <= 26; i++){
            int v = trie[u].fail;
            while(!trie[v].trans[i]) v = trie[v].fail;
            int w = trie[u].trans[i];
            v = trie[v].trans[i];
            if(w){
                trie[w].fail = v;
                que.push(w);
                trie[w].state |= trie[v].state;
            }
            else trie[u].trans[i] = v;
        }
    }
}
inline int count(int x){
    int ret = 0;
    while(x){
        ret++;
        x -= (x & -x);
    }
    return ret;
}
inline void solve(){
    memset(dp, 0, sizeof dp);
    int limit = 1 << m;
    dp[0][1][0] = 1;
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= tot; j++)
            for(int sta = 0; sta < limit; sta++)
                if(dp[i - 1][j][sta])
                    for(int l = 1; l <= 26; l++){
                        int u = trie[j].trans[l];
                        dp[i][u][sta | trie[u].state] = (dp[i][u][sta | trie[u].state] + dp[i - 1][j][sta]) % Mod;
                    }
    for(int i = 1; i <= tot; i++)
        for(int sta = 0; sta < limit; sta++){
            if(count(sta) >= k)
                ans = (ans + dp[n][i][sta]) % Mod;
        }
}
int main(){
    while(scanf("%d%d%d", &n, &m, &k), n + m + k){
        trie[tot = 1].clear(); ans = 0;
        for(int i = 1; i <= m; i++){
            scanf("%s", s + 1);
            insert(i - 1);
        }
        buildFail();
        solve();
        cout << ans << endl;
    }
}
原文地址:https://www.cnblogs.com/CzYoL/p/7450429.html