AC自动机笔记

好久以前做的AC自动机题目了，趁还没忘掉赶紧回忆一下。
AC自动机简单说就是：预先知道所有的模式串，先用它们建立Trie树，再在树上构造fail指针（失配时跳到当前串在Trie中存在的最长真后缀对应的结点），这样只需扫描一遍目标串就能同时匹配所有模式串。

【模板】AC自动机(二次加强版)

洛谷P5357
AC自动机裸题，加了last优化（last指向沿fail链最近的、有单词结尾的结点，匹配时直接跳过中间没有单词结尾的结点）后通过了 = =

#include <bits/stdc++.h>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
//clock_t c1 = clock();
//std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#define de(a) cout << #a << " = " << a << endl
#define rep(i, a, n) for (int i = a; i <= n; i++)
#define per(i, a, n) for (int i = n; i >= a; i--)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
// std::vector's second template parameter is the allocator, so the former
// `vector<int, int>` could never be instantiated; a vector of int pairs is meant.
typedef vector<PII> VII;
#define inf 0x3f3f3f3f
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll MAXN = 2e6 + 7; // capacity of the trie-node arrays and of the text buffers
const ll MAXM = 1e6 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-6;
const double pi = acos(-1.0);
int tree[MAXN][26]; // tree[i][j]: child of node i on letter 'a'+j (0 = none); after build() also the goto transition
int flagg[MAXN];    // number of the first pattern that ends at this node, 0 if none
// int cntword[MAXN];  // number of words ending at this node
int sum[MAXN];      // number of inserted patterns whose path passes through this node
int fail[MAXN];     // fail link: node of the longest proper suffix that is also in the trie
int vis[MAXN];      // vis[k]: node at which pattern number k ends
int mp[MAXN];       // match counters filled by query(), indexed by node
int tot;            // index of the last allocated node (the root is 0)
int last[MAXN];     // nearest proper fail ancestor carrying a pattern (0 if none)
/*
 * Adds pattern `str` (lowercase letters only) to the trie as pattern number `num`.
 * Each node on the path gets its pass-through counter `sum` bumped. The end
 * node keeps the first pattern number that reached it (flagg), and vis[num]
 * remembers that end node so duplicate patterns share one counter.
 */
void insert_(char *str, int num)
{
    int node = 0;
    for (char *p = str; *p; ++p)
    {
        int &child = tree[node][*p - 'a'];
        if (child == 0)
            child = ++tot; // allocate a fresh node for an unseen edge
        node = child;
        ++sum[node];
    }
    if (flagg[node] == 0)
        flagg[node] = num;
    vis[num] = node;
}
void init() //最后清空,节省时间
{
    for (int i = 0; i <= tot; i++)
    {
        sum[i] = 0;
        flagg[i] = false;
        for (int j = 0; j < 26; j++)
            tree[i][j] = 0;
        last[i] = 0;
    }
    tot = 0; //RE有可能是这里的问题
}
/*
 * Breadth-first construction of fail links over the trie built by insert_().
 * Missing edges are redirected to the matching edge of the fail node, turning
 * tree[][] into a complete transition table. last[v] becomes the closest
 * proper fail ancestor of v that carries a pattern (inherited from the fail
 * node when the fail node itself carries none), so query() can skip
 * pattern-free nodes.
 */
void build()
{
    queue<int> bfs;
    for (int c = 0; c < 26; c++)
    {
        int first = tree[0][c];
        if (!first)
            continue;
        fail[first] = 0; // depth-1 nodes always fall back to the root
        bfs.push(first);
    }
    while (!bfs.empty())
    {
        int cur = bfs.front();
        bfs.pop();
        int back = fail[cur];
        for (int c = 0; c < 26; c++)
        {
            int child = tree[cur][c];
            if (!child)
            {
                tree[cur][c] = tree[back][c]; // borrow the fail node's transition
                continue;
            }
            int f = tree[back][c];
            fail[child] = f;
            last[child] = flagg[f] ? f : last[f];
            bfs.push(child);
        }
    }
}
void query(char *s)
{
    int len = strlen(s);
    int now = 0, ans = 0;
    for (int i = 0; i < len; i++)
    {
        now = tree[now][s[i] - 'a'];
        if (flagg[now])
        {
            mp[now]++;
            // flagg[now] = 0;
        }
        int t = now;
        while (last[t])
        {
            t = last[t];
            if (flagg[t])
            {
                mp[t]++;
                // flagg[t] = 0;
            }
        }
    }
}
char s[MAXN]; //模式串
char t[MAXN]; //目标串
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf(" %s", s);
        insert_(s, i);
    }
    build();
    scanf(" %s", t);
    query(t);
    for (int i = 1; i <= n; i++)
        printf("%d
", mp[vis[i]]);
    return 0;
}

HDU-1277全文检索

HDU 全文检索

#include <bits/stdc++.h>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
//clock_t c1 = clock();
//std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#define de(a) cout << #a << " = " << a << endl
#define rep(i, a, n) for (int i = a; i <= n; i++)
#define per(i, a, n) for (int i = n; i >= a; i--)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
// std::vector's second template parameter is the allocator, so the former
// `vector<int, int>` could never be instantiated; a vector of int pairs is meant.
typedef vector<PII> VII;
#define inf 0x3f3f3f3f
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll MAXN = 1e6 + 7; // capacity of the trie-node arrays and of the text buffers
const ll MAXM = 1e6 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-6;
const double pi = acos(-1.0);
int tree[MAXN][15]; // tree[i][d]: child of node i on digit d (only d < 10 is used); after build() also the goto transition
bool flagg[MAXN];   // true if some keyword ends at this node
int cntword[MAXN];  // -1: no (still unreported) keyword ends here; otherwise the insert count at this node minus one
int mark[MAXN];     // keyword number passed to insert_() for the keyword ending here; later inserts overwrite
int sum[MAXN];      // number of inserted keywords whose path passes through this node
int fail[MAXN];     // fail link: node of the longest proper suffix that is also in the trie
int tot;            // index of the last allocated node (the root is 0)
/*
 * Inserts keyword `str` (decimal digits only) into the trie under keyword
 * number `id`. Every node on the path has its pass-through counter `sum`
 * incremented. The end node is flagged, its cntword counter is incremented
 * (-1 -> 0 on the first insertion) and mark[] is set to `id`.
 */
void insert_(char *str, int id)
{
    int len = strlen(str);
    int root = 0;
    for (int i = 0; i < len; i++)
    {
        // Named `digit`, not `id`: reusing the name shadowed the parameter
        // that is stored in mark[] below.
        int digit = str[i] - '0';
        if (!tree[root][digit])
            tree[root][digit] = ++tot;
        sum[tree[root][digit]]++;
        root = tree[root][digit];
    }
    flagg[root] = true;
    cntword[root]++;
    mark[root] = id;
}
/*
 * Clears nodes 0..tot so the next test case starts from an empty trie. For
 * each node the ten digit edges are removed, cntword returns to -1 ("no
 * keyword ends here"), and the flag and pass-through counter are zeroed.
 * Only touching the used nodes is cheaper than wiping the whole arrays.
 */
void init()
{
    for (int node = tot; node >= 0; --node)
    {
        for (int d = 0; d < 10; ++d)
            tree[node][d] = 0;
        cntword[node] = -1;
        flagg[node] = false;
        sum[node] = 0;
    }
    tot = 0;
}
/*
 * Breadth-first pass that fills fail[] for every trie node and completes the
 * transition table. A missing digit edge of node v is redirected to the
 * same-digit edge of fail[v], so query() can advance with a single lookup.
 */
void build()
{
    queue<int> bfs;
    for (int d = 0; d < 10; d++)
    {
        int first = tree[0][d];
        if (!first)
            continue;
        fail[first] = 0; // children of the root fall back to the root
        bfs.push(first);
    }
    while (!bfs.empty())
    {
        int cur = bfs.front();
        bfs.pop();
        int back = fail[cur];
        for (int d = 0; d < 10; d++)
        {
            int child = tree[cur][d];
            if (!child)
            {
                tree[cur][d] = tree[back][d]; // borrow the fail node's transition
                continue;
            }
            fail[child] = tree[back][d];
            bfs.push(child);
        }
    }
}
bool flag = true;
void query(char *s)
{
    int len = strlen(s);
    int now = 0;
    for (int i = 0; i < len; i++)
    {
        now = tree[now][s[i] - '0'];
        for (int j = now; j && ~cntword[j]; j = fail[j])
        {
            if (!flag)
                printf(" ");
            printf("[Key No. %d]", mark[j]);
            flag = false;
            cntword[j] = -1;
        }
    }
}
char s[MAXN];    // buffer for one keyword's digits
char t[MAXN];    // information stream: all input lines concatenated
char temp[MAXN]; // buffer for one input line
/*
 * HDU 1277: each test case gives n lines of digits (concatenated into t)
 * followed by m keywords of the form "[Key No. K] digits". The keywords
 * found in t are printed on one line after "Found key: ".
 */
int main()
{
    memset(cntword, -1, sizeof(cntword)); // -1 marks "no keyword ends here"
    int n, m;
    while (~scanf("%d%d", &n, &m))
    {
        // Append each line at the tracked end of t; strcat would rescan t from
        // the beginning on every call (quadratic in the stream length).
        int used = 0;
        t[0] = '\0';
        for (int i = 0; i < n; i++)
        {
            scanf(" %s", temp);
            int piece = strlen(temp);
            memcpy(t + used, temp, piece + 1); // copies the terminating '\0' too
            used += piece;
        }
        for (int j = 1; j <= m; j++)
        {
            int num;
            // "[Key" and "No." are skipped, then the number, the ']' and the digits
            scanf("%*s%*s%d] %s", &num, s);
            insert_(s, num);
        }
        build();
        // NOTE(review): when no keyword occurs, only "Found key: " is printed;
        // the problem statement may require a dedicated message - confirm.
        printf("Found key: ");
        query(t);
        puts("");
        init();
    }
    return 0;
}
原文地址:https://www.cnblogs.com/graytido/p/11577764.html