BZOJ2434 [Noi2011]阿狸的打字机 [AC自动机, 树状数组]

阿狸的打字机

题目描述见链接 .


color{red}{正解部分}

按题意可建出 TrieTrie树, B 退格表示退回到 TrieTrie树中的父节点 .

在这个 TrieTrie树的基础上 建立 AcAc 自动机, 构建 failfail 树,

对于一个询问 (x,y)(x, y), 只需求出: TrieTrie树 中 rootrootyy 路径上的点有多少点的 failfail 树中的祖先 包含 xx 节点 .

为什么?? 这是因为 rootyroot ightarrow y 的路径上的节点表示的是 yy字符串 的 前缀, 而 failfail树上 一个节点的祖先为这个节点的 后缀,
所以上方相当于求出 yy字符串 所有前缀 的 所有后缀 有多少与 xx 匹配, 等同于 xxyy 中出现了多少次 .


现在考虑怎么去解决这个问题,
:首先上暴力:
对每个询问 (x,y)(x,y), 从 rootroot 走到 yy, 路上不断跳 failfail, 使用 存上 xx 节点出现次数, 如果实现方法较优, 可以得到 70pts70pts.


:接下来考虑怎么优化:
上方暴力算法有 22 点严重拖慢程序速度:

  1. 多次从 rootroot 重新出发走到 yy .
  2. failfail 找祖先 .
  • 对于第 11 点, 可以在 TrieTrie树 上标记每个 yy 所在的位置, 然后从 rootroot 出发, 遍历整棵树, 到达一个 yy, 就进行处理 .

  • 对于第 22 点, 在 failfail树 上 xxpp 的祖先, 可以看做 failfail树上 xx 的子树包含 pp,
    于是只需求出 rootyroot ightarrow y 有多少节点在 failfail树中 在 xx 的子树内,
    这个问题就好解决了, 求出所有节点在 failfail树上 的 dfsdfs序, 结合在 TrieTrie树上的DFSDFS回溯操作 和 树状数组 维护即可 .


color{red}{实现部分}

  • 由于虚拟节点的存在, 树状数组 的上界需要开大 11 .
  • 注意 TrieTrie树 一定要拷贝一份, 在建 ACAC自动机 的过程中会使得 TrieTrie树 的儿子节点变化, 导致在后面 DFSDFS 时爆炸 .
#include<bits/stdc++.h>
#define reg register
#define pb push_back

const int maxn = 2e5 + 5;

int read(){
        char c;
        int s = 0, flag = 1;
        while((c=getchar()) && !isdigit(c))
                if(c == '-'){ flag = -1, c = getchar(); break ; }
        while(isdigit(c)) s = s*10 + c-'0', c = getchar();
        return s * flag;
}

int N;
int M;
int cur;
int num0;
int dfs_tim;
int str_cnt;
int node_cnt;
int Fa[maxn];
int Mp[maxn];
int end[maxn];
int dfn[maxn];
int Ans[maxn];
int head[maxn];

char Smp_1[maxn];

struct Que{ int x, id; Que(int x=0, int id=0):x(x), id(id) {} };

struct Node{ int nxt, Ch[30], ch[30]; } Trie_t[maxn];

struct Edge{ int nxt, to; } edge[maxn << 1];

struct Bit_tree{
        int lim;
        int v[maxn];
        void Add(int k, int x){ while(k<=lim)v[k]+=x,k+=k&-k; }
        int Query(int k){ if(!k)return 0; int s=0; while(k)s+=v[k],k-=k&-k; return s; }
} bit_t;

std::vector <Que> B[maxn];

void Add(int from, int to){ edge[++ num0] = (Edge){ head[from], to }; head[from] = num0; }

void Build_Ac(){
        std::queue <int> Q;
        for(reg int i = 0; i < 26; i ++) if(Trie_t[0].ch[i]) Q.push(Trie_t[0].ch[i]);
        while(!Q.empty()){
                int ft = Q.front(); Q.pop();
                for(reg int i = 0; i < 26; i ++){
                        int &to = Trie_t[ft].ch[i];
                        if(to) Trie_t[to].nxt = Trie_t[Trie_t[ft].nxt].ch[i], Q.push(to);
                        else to = Trie_t[Trie_t[ft].nxt].ch[i];
                }
        }
}

void DFS_fail(int k, int fa){
        dfn[k] = ++ dfs_tim;
        for(reg int i = head[k]; i; i = edge[i].nxt){
                int to = edge[i].to;
                if(to == fa) continue ;
                DFS_fail(to, k);
        }
        end[k] = dfs_tim;
}

void DFS(int k){
        bit_t.Add(dfn[k], 1); 
        int siz = B[k].size(); 
        for(reg int j = 0; j < siz; j ++){ 
                int x = B[k][j].x;
                Ans[B[k][j].id] = bit_t.Query(end[x]) - bit_t.Query(dfn[x]-1); 
        }
        for(reg int i = 0; i < 26; i ++){ int to = Trie_t[k].Ch[i]; if(!to) continue ; DFS(to); }
        bit_t.Add(dfn[k], -1);
}

int main(){
        freopen("a.in", "r", stdin);
        freopen("a.out", "w", stdout);
        scanf("%s", Smp_1+1);
        N = strlen(Smp_1+1);
        cur = 0;
        for(reg int i = 1; i <= N; i ++){
                char t = Smp_1[i];
                if(t == 'B') cur = Fa[cur];
                else if(t == 'P') Mp[++ str_cnt] = cur;
                else{
                        if(!Trie_t[cur].ch[t-'a']) Trie_t[cur].ch[t-'a'] = ++ node_cnt;
                        Fa[Trie_t[cur].ch[t-'a']] = cur;
                        cur = Trie_t[cur].ch[t-'a'];
                }
        }
        for(reg int i = 0; i <= node_cnt; i ++)
                for(reg int j = 0; j < 26; j ++) Trie_t[i].Ch[j] = Trie_t[i].ch[j];
        Build_Ac();
        for(reg int i = 1; i <= node_cnt; i ++) Add(Trie_t[i].nxt, i);
        DFS_fail(0, 0);
        M = read();
        for(reg int i = 1; i <= M; i ++){ int x = read(), y = read(); B[Mp[y]].pb(Que(Mp[x], i)); }
        bit_t.lim = node_cnt+1; DFS(0);
        for(reg int i = 1; i <= M; i ++) printf("%d
", Ans[i]);
        return 0;
}
原文地址:https://www.cnblogs.com/zbr162/p/11822490.html