HDU 2825 Wireless Password ( Trie图 && 状态压缩DP )

题意 : 输入n、m、k意思就是给你 m 个模式串,问你构建长度为 n 至少包含 k 个模式串的方案有多少种

分析 : ( 以下题解大多都是在和 POJ 2778 && POJ 1625 && HDU 2243 进行类比,如果没做过的话.......可能看不懂 )

这道题如果去对比之前做过的 POJ 2778 And HDU 2243 可以发现现在的难点在于如何找出至少包含 k 个模式串的,这里我们给每一个单词编号,对于在DP过程当中选中了这个单词就标记一下,但是问题是如何判断是否被重复选过以及如何标记?如果直接开在DP的维度上那是不可能的!(即开一个十维数组来记录每一个单词是否被选过),如果能将这十个维度压缩成一个维度多好啊!没错了,使用二进制,二进制的每一位代表每一个单词,0就是没被选,1就是被选,最多只需要十位便能记录这些信息,转化为十进制之后我们只要给DP再多添一个维度即可。那么代码怎么实现呢?只需要用到 | 操作符即可,这个操作符会把每一位的 1 进行叠加,相当于吧信息进行叠加!不过这里需要注意,在做POJ 1625 的时候是用一个矩阵来转移的,而这道题不能,因为我们还附加了一个选了哪几个单词这一个维度,如果全部笼统的映射到矩阵上的话是不行的!比如 G[i][j] 代表从 i 到 j 一步走完的可行方案,但是我们不知道这些方案到底走过了那些单词,所以需要一步步转移,如果把 POJ 1625 的代码不用矩阵来储存信息那要怎么写呢?如下( 如果你之前看过我 POJ 1625 的代码的话 )

        for(int i=0; i<m; i++)
        for(int j=0; j<ac.Size; j++){
//            for(int k=0; k<ac.Size; k++){
//                dp[i+1][k] += dp[i][j] * G[j][k];
//            }
            for(int k=0; k<n; k++){
                int tmp = ac.Node[j].Next[k];
                if(!ac.Node[tmp].flag)
                    dp[i+1][tmp] += dp[i][j];
            }
        }

所以在网上搜这道题的题解的时候,一开始我有点懵,为什么有四层for?这和用矩阵有什么区别?实际上矩阵写法完全可以写成如上所示!

因此总结一下就是 利用DP[i][j][k] 表示 DP[第几步][哪个节点结尾][当前选了哪些单词] = 方案数

这里有个大优化,因为我们的DP是一个向前的DP,这样的DP有一个优势,就是当 DP值 == 0 的时候将不会对后面有任何影响,直接continue便能减去大量的“枝”!

#include<string.h>
#include<stdio.h>
#include<queue>
using namespace std;
const int Max_Tot = 111;
const int Letter = 26;
const int MOD = 20090717;
int dp[30][111][(1<<10)+1];
int cnt[1111];
struct Aho{
    struct StateTable{
        int Next[Letter];
        int fail, id;
    }Node[Max_Tot];
    int Size;
    queue<int> que;

    inline void init(){
        while(!que.empty()) que.pop();
        memset(Node[0].Next, 0, sizeof(Node[0].Next));
        Node[0].fail = Node[0].id = 0;
        Size = 1;
    }

    inline void insert(char *s, int id){
        int now = 0;
        for(int i=0; s[i]; i++){
            int idx = s[i] - 'a';
            if(!Node[now].Next[idx]){
                memset(Node[Size].Next, 0, sizeof(Node[Size].Next));
                Node[Size].fail = Node[Size].id = 0;
                Node[now].Next[idx] = Size++;
            }
            now = Node[now].Next[idx];
        }
        Node[now].id |= (1<<id);
    }

    inline void BuildFail(){
        Node[0].fail = 0;
        for(int i=0; i<Letter; i++){
            if(Node[0].Next[i]){
                Node[Node[0].Next[i]].fail = 0;
                que.push(Node[0].Next[i]);
            }else Node[0].Next[i] = 0;
        }
        while(!que.empty()){
            int top = que.front(); que.pop();
            Node[top].id |= Node[Node[top].fail].id;
            for(int i=0; i<Letter; i++){
                int &v = Node[top].Next[i];
                if(v){
                    que.push(v);
                    Node[v].fail = Node[Node[top].fail].Next[i];
                }else v = Node[Node[top].fail].Next[i];
            }
        }
    }
}ac;
char S[20];
int main(void)
{
    for(int i=0; i<(1<<10); i++){
        cnt[i] = 0;
        for(int j=0; j<10; j++){
            if(i & (1<<j))
                cnt[i]++;
        }
    }

    int n, m, k;
    while(~scanf("%d %d %d", &n, &m, &k)){
        if(n==0 && m==0 && k==0) break;
        ac.init();
        for(int i=0; i<m; i++){
            scanf("%s", S);
            ac.insert(S, i);
        }
        ac.BuildFail();

        for(int i=0; i<=n; i++)
            for(int j=0; j<ac.Size; j++)
                for(int l=0; l<(1<<m); l++)
                    dp[i][j][l] = 0;

        dp[0][0][0] = 1;

        for(int i=0; i<n; i++){
            for(int j=0; j<ac.Size; j++){
                for(int l=0; l<(1<<m); l++){
                    if(dp[i][j][l] > 0){
                        for(int x=0; x<Letter; x++){///不用矩阵了,一步步转移
                            int newi = i+1;
                            int newj = ac.Node[j].Next[x];
                            int newl = (ac.Node[ newj ].id) | l;///加入了哪个新节点就应该去叠加状态!
                            dp[newi][newj][newl] += dp[i][j][l];
                            dp[newi][newj][newl] %= MOD;
                        }
                    }
                }
            }
        }

        int ans = 0;
        for(int i=0; i<(1<<m); i++){
            if(cnt[i]>=k){
                for(int j=0; j<ac.Size; j++){
                    ans = (ans + dp[n][j][i])%MOD;
                }
            }
        }

        printf("%d
", ans%MOD);
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/qwertiLH/p/7632998.html