HDU 2896 病毒侵袭 【AC自动机】

HDU 2222 只要求统计与文本串匹配的模式串个数,而本题要求输出所匹配到的每个模式串的编号。

与 HDU 2222 代码的不同之处已在代码注释中标出。


#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
#include <algorithm>
// Trie node pool size. chd alone occupies MAX_NODE * MAX_CHILD * sizeof(int)
// bytes (~31 MB here); presumably chosen to fit the judge's memory limit --
// TODO confirm before raising it.
#define MAX_NODE 60005
// Number of child slots per node (characters 0..MAX_CHILD-1).
#define MAX_CHILD 130
using namespace std;
// Pattern ids reported by Query(); the caller clears it between texts.
vector<int> ans;

// Aho-Corasick automaton over a fixed-size alphabet. Node 1 is the root;
// a 0 entry in chd means "no edge".
class AC_Automaton {
public:
    int chd[MAX_NODE][MAX_CHILD];  // trie transitions, 0 = absent
    int fail[MAX_NODE];            // failure links (root links to itself)
    int val[MAX_NODE];             // id of the pattern ending here, 0 if none
    int ID[130];                   // character -> child slot (identity map)
    int sz;                        // highest node index currently in use
    queue<int> q;                  // BFS work queue used by Build_AC

    AC_Automaton() {
        for (int i = 0; i < 130; i++) ID[i] = i;
        Clear();
    }

    // Maps a character to its child slot, or -1 if it is outside the alphabet.
    // Going through unsigned char avoids negative indices when plain char is
    // signed and the input contains bytes >= 0x80.
    int Slot(char c) const {
        const int u = static_cast<unsigned char>(c);
        if (u >= MAX_CHILD || u >= static_cast<int>(sizeof(ID) / sizeof(ID[0])))
            return -1;
        return ID[u];
    }

    // Resets the automaton to an empty trie containing only the root.
    void Clear() {
        memset(chd, 0, sizeof (chd));
        memset(fail, 0, sizeof (fail));
        memset(val, 0, sizeof (val));
        sz = 1;
    }

    // Inserts pattern s and tags its terminal node with id v.
    // HDU 2222 counted patterns with val[cur]++; this problem needs the id.
    // A pattern containing a character outside the alphabet, or one that
    // would exhaust the node pool, is not registered (instead of writing out
    // of bounds).
    void Insert(const char *s, int v) {
        int cur = 1;
        for (int i = 0; s[i]; i++) {
            const int c = Slot(s[i]);
            if (c < 0) return;
            if (!chd[cur][c]) {
                if (sz + 1 >= MAX_NODE) return;  // node pool exhausted
                chd[cur][c] = ++sz;
            }
            cur = chd[cur][c];
        }
        val[cur] = v;
    }

    // Computes failure links breadth-first from the root.
    void Build_AC() {
        while (!q.empty()) q.pop();
        q.push(1);
        fail[1] = 1;
        while (!q.empty()) {
            int cur = q.front();
            q.pop();
            for (int i = 0; i < MAX_CHILD; i++)
                if (chd[cur][i]) {
                    if (cur == 1) fail[chd[cur][i]] = 1;
                    else {
                        int tmp = fail[cur];
                        while (tmp != 1 && chd[tmp][i] == 0) tmp = fail[tmp];
                        if (chd[tmp][i]) fail[chd[cur][i]] = chd[tmp][i];
                        else fail[chd[cur][i]] = 1;
                    }
                    q.push(chd[cur][i]);
                }
        }
    }

    // Scans text s and appends to the global `ans` the id of every pattern
    // occurring in it. Each id is appended at most once per call, and the
    // appended ids are in ascending order; entries already in `ans` are left
    // untouched. Returns the number of ids appended.
    // HDU 2222 instead summed val[] and set val[tmp] = -1 to avoid recounting,
    // which destroys the automaton for later texts; here nodes stay intact,
    // so duplicates are removed after the scan.
    int Query(const char *s) {
        const size_t start = ans.size();
        int cur = 1;
        for (int i = 0; s[i]; i++) {
            const int c = Slot(s[i]);
            if (c < 0) {  // character no pattern can contain: restart at root
                cur = 1;
                continue;
            }
            while (cur != 1 && chd[cur][c] == 0) cur = fail[cur];
            if (chd[cur][c]) cur = chd[cur][c];
            // Every suffix reachable via fail links may end a pattern.
            for (int tmp = cur; tmp != 1; tmp = fail[tmp])
                if (val[tmp]) ans.push_back(val[tmp]);
        }
        sort(ans.begin() + start, ans.end());
        ans.erase(unique(ans.begin() + start, ans.end()), ans.end());
        return static_cast<int>(ans.size() - start);
    }
} AC;

char s[210], text[10005];

int main() {
    int n, m;
    while (scanf("%d", &n) == 1) {
        AC.Clear();
        for (int i=1; i<=n; i++) {
            scanf(" %s", s); AC.Insert(s, i);
        }
        AC.Build_AC();

        scanf("%d", &m);
        int tot = 0, cnt;
        for (int i=1; i<=m ;i++) {
            scanf(" %s", &text);
            ans.clear();

            AC.Query(text);

            if (ans.size() == 0) continue;
            sort(ans.begin(), ans.end());

            printf("web %d:", i);
            for (int j=0; j<ans.size(); j++) printf(" %d", ans[j]);
            printf("
");
            tot++;
        }
        printf("total: %d
", tot);
    }
    return 0;
}


原文地址:https://www.cnblogs.com/javawebsoa/p/3241066.html