阿狸的打字机(AC自动机+fail树)

bzoj2434: [Noi2011]阿狸的打字机
阿狸喜欢收藏各种稀奇古怪的东西,最近他淘到一台老式的打字机。打字机上只有28个按键,分别印有26个小写英文字母和’B’、’P’两个字母。

经阿狸研究发现,这个打字机是这样工作的:
l 输入小写字母,打字机的一个凹槽中会加入这个字母(这个字母加在凹槽的最后)
l 按一下印有’B’的按键,打字机凹槽中最后一个字母会消失。
l 按一下印有’P’的按键,打字机会在纸上打印出凹槽中现有的所有字母并换行,但凹槽中的字母不会消失。
例如,阿狸输入aPaPBbP,纸上被打印的字符如下:
a
aa
ab
我们把纸上打印出来的字符串从1开始顺序编号,一直到n。打字机有一个非常有趣的功能,在打字机中暗藏一个带数字的小键盘,在小键盘上输入两个数(x,y)(其中1≤x,y≤n),打字机会显示第x个打印的字符串在第y个打印的字符串中出现了多少次。
阿狸发现了这个功能以后很兴奋,他想写个程序完成同样的功能,你能帮助他么?

Input
输入的第一行包含一个字符串,按阿狸的输入顺序给出所有阿狸输入的字符。
第二行包含一个整数m,表示询问个数。
接下来m行描述所有由小键盘输入的询问。其中第i行包含两个整数x, y,表示第i个询问为(x, y)。

Output
输出m行,其中第i行包含一个整数,表示第i个询问的答案。

Sample Input
aPaPBbP
3
1 2
1 3
2 3

Sample Output
2
1
0

HINT
1<=N<=10^5
1<=M<=10^5
输入总长<=10^5

分析:网上的题解在我这个zz看来真是晦涩难懂啊,所以只能一点点研究,写了一篇蒟蒻看的题解(大神勿喷):
这道题需要用到fail树(实际上是所有fail指针反向构成的树),简单解释一下:
在fail树中,假使有一个节点对应的字符串为aa,那么所有以aa为后缀字符串都在这个节点的子树里。如果串x出现在串y中,那么串y有几个前缀的后缀以串x结尾,便是出现的次数。而这些y串上的节点都会出现在x的子树中。
简单地说串y从某个位置顺着fail指针能到达串x尾就增加一次。
我们需要计算的就是以x串最后一个字符所在的节点为根的子树中找到包含多少个y串中的节点,把x串最后一个字符所在节点为根的子树中的y串上的节点独立出来,把这些节点的值赋为1,对于所有询问我们先按照y排一遍序,以后就可以把每一个y标记完的树,处理所有对应的(x[i],y)询问,求一个区间和(树状数组维护)

这里写代码片
#include<cstdio>
#include<cstring>
#include<iostream>
#include<queue>
#include<algorithm>

using namespace std;

const int N=100010;
int n,m,x,y;
char s[N];
int ch[N][26],fa[N],tot=0;  //tot:节点数 
int word[N];  //word:每个输出字符串对应结尾 
int tt=0,fail[N],in[N],out[N],ed=-1;  //ed:辅助dfs计数  tt:辅助记录输出字符串的编号 
struct node{
    int x,y,next;
};
node tree[N*4]; //fail树 
int st[N],totw=0,ans[N];  //ans:记录答案 
struct node2{
    int x,y,id;
};
node2 qes[N];  //记录询问 
int  C[N];  //树状数组 

int comp(const node2 & a,const node2 & b)
{
    if (a.y!=b.y) return a.y<b.y;  //按y(也就是‘字典’)排序,在使用每一个字典时,把所有关于他的询问都处理完 
    else return a.x<b.x;
}

void add(int u,int w)
{
    totw++;
    tree[totw].x=u;
    tree[totw].y=w;
    tree[totw].next=st[u];
    st[u]=totw;
    return;
}

void build()   //trie树的构建 
{
    int i,now=0;
    int len=strlen(s);
    for (i=0; i<len; i++) 
    {
        int x=s[i]-'a';
        if (s[i]=='P') 
        {
            word[++tt]=now;
            continue;
        }
        if (s[i]=='B') 
        {
            now=fa[now];
            continue;
        }
        if (!ch[now][x]) ch[now][x]=++tot,fa[ch[now][x]]=now;
        now=ch[now][x];
    }
    return;  
}

void make() //生成失配指针。顺便反向建树。
{
    int i;
    queue<int> q;
    for (i=0; i<26; i++)
        if (ch[0][i])
            q.push(ch[0][i]);
    while (!q.empty()) {
        int r=q.front();
        q.pop();
        for (i=0; i<26; i++) 
        {
            if (!ch[r][i]) 
            {
                ch[r][i]=ch[fail[r]][i];
                continue;
            }
            fail[ch[r][i]]=ch[fail[r]][i];
            q.push(ch[r][i]);
        }
    }
    for (i=1;i<=tot;i++)
       add(fail[i],i);
    return; 
}

void dfs(int t)  //对fail树dfs 根节点是0所以ed初始值是-1,保证t和ed的值能对应的上 
{
    int i;
    in[t]=++ed;  //节点t在dfs中的起始位置 
    for (i=st[t];i;i=tree[i].next)
    {
        if (tree[i].y!=t)
           dfs(tree[i].y);
    }
    out[t]=ed;  //在dfs中的终止位置 
}

//树状数组时建立在dfs序上的,这样能确保同一颗子树上的节点在序列中是连续的一段 
void change(int bh,int z)
{
    int i;
    for (i=bh;i<=tot;i+=i&(-i))
        C[i]+=z;
    return;
}

int ask(int bh)
{
    int i,an=0;
    for (i=bh;i;i-=i&(-i))
        an+=C[i];
    return an;
}

void solve() //把字符串重新遍历一遍 
{
    int i,now=0,js=0,j=1;
    int len=strlen(s);
    for (i=0;i<len;i++)
    {
        int x=s[i]-'a';
        if (x>=0&&x<26)
        {
            now=ch[now][x];
            change(in[now],1);
        }
        else if (s[i]=='B')
        {
            change(in[now],-1);
            now=fa[now];
        }
        else
        {
            js++;
            while (qes[j].y==js)
            {
                ans[qes[j].id]=ask(out[word[qes[j].x]])-ask(in[word[qes[j].x]]-1);
                j++;
            }
        }
    }
    return;
}

int main() 
{
    scanf("%s",&s);
    build();
    make();
    dfs(0);
    scanf("%d",&m);
    for (int i=1;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        qes[i].id=i;
        qes[i].x=x;
        qes[i].y=y;
    }
    sort(qes+1,qes+1+m,comp);
    solve();
    for (int i=1;i<=m;i++)
        printf("%d
",ans[i]); 
    return 0;
}
原文地址:https://www.cnblogs.com/wutongtong3117/p/7673645.html