BZOJ4231 回忆树

(kmp) 暴力处理经过 (lca) 的匹配,这一部分复杂度为 (O(sum|s|))。然后就只用考虑直上直下的链的匹配,离线后对询问串建 (AC) 自动机,在原树上遍历时加入贡献,答案差分统计,即长链的匹配减去短链的匹配,用树状数组维护 (fail) 树子树和即可。

#include<bits/stdc++.h>
#define maxn 300010
#define lowbit(x) (x&(-x))
using namespace std;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int n,m,root;
int ans[maxn],f[maxn][19],dep[maxn],nxt[maxn];
char s[maxn],col[maxn],str[maxn],tmp[maxn];
struct node
{
    int p,v,id;
    node(int a=0,int b=0,int c=0)
    {
        p=a,v=b,id=c;
    }
};
vector<node> q1[maxn],q2[maxn];
struct edge
{
    int to,nxt;
    char v;
    edge(int a=0,int b=0,char c=0)
    {
        to=a,nxt=b,v=c;
    }
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to,char val)
{
    e[++edge_cnt]=edge(to,head[from],val),head[from]=edge_cnt;
}
struct AC
{
    int tot,cnt;
    int ch[maxn][28],fail[maxn],in[maxn],out[maxn],tr[maxn];
    vector<int> ve[maxn];
    void update(int x,int v)
    {
        if(!x) return;
        x=in[x];
        while(x<=n) tr[x]+=v,x+=lowbit(x);
    }
    int ask(int x)
    {
        int v=0;
        while(x) v+=tr[x],x-=lowbit(x);
        return v;
    }
    int query(int x)
    {
        return ask(out[x])-ask(in[x]-1);
    }
    int insert(int type=0)
    {
        int p=root,len=strlen(s+1);
        if(type) reverse(s+1,s+len+1);
        for(int i=1;i<=len;++i)
        {
            int c=s[i]-'a';
            if(!ch[p][c]) ch[p][c]=++tot;
            p=ch[p][c];
        }
        if(type) reverse(s+1,s+len+1);
        return p;
    }
    void dfs_dfn(int x)
    {
        in[x]=++cnt;
        for(int i=0;i<ve[x].size();++i) dfs_dfn(ve[x][i]);
        out[x]=cnt;
    }
    void build()
    {
        queue<int> q;
        for(int c=0;c<26;++c)
            if(ch[root][c])
                q.push(ch[root][c]);
        while(!q.empty())
        {
            int x=q.front();
            q.pop();
            for(int c=0;c<26;++c)
            {
                int y=ch[x][c];
                if(y) fail[y]=ch[fail[x]][c],q.push(y);
                else ch[x][c]=ch[fail[x]][c];
            }
        }
        for(int i=1;i<=tot;++i) ve[fail[i]].push_back(i);
        dfs_dfn(root);
    }
}A,B;
void dfs_pre(int x,int fa)
{
    dep[x]=dep[f[x][0]=fa]+1;
    for(int i=1;i<=17;++i) f[x][i]=f[f[x][i-1]][i-1];
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa) continue;
        col[y]=e[i].v,dfs_pre(y,x);
    }
}
int lca(int x,int y)
{
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=17;i>=0;--i)
        if(f[x][i]&&dep[f[x][i]]>=dep[y])
            x=f[x][i];
    if(x==y) return x;
    for(int i=17;i>=0;--i)
        if(f[x][i]&&f[x][i]!=f[y][i])
            x=f[x][i],y=f[y][i];
    return f[x][0];
}
int get(int x,int k)
{
    for(int i=0;i<=17;++i)
        if((k>>i)&1)
            x=f[x][i];
    return x;
}
void work(int x,int y,int id)
{
    int anc=lca(x,y),p1=A.insert(),p2=B.insert(1),len=strlen(s+1),p,cnt1=0,cnt2=0,pos=0;
    p=get(x,max(dep[x]-dep[anc]-len+1,0));
    q2[x].push_back(node(p2,1,id)),q2[p].push_back(node(p2,-1,id));
    while(p!=anc) str[++cnt1]=col[p],p=f[p][0];
    p=get(y,max(dep[y]-dep[anc]-len+1,0));
    q1[y].push_back(node(p1,1,id)),q1[p].push_back(node(p1,-1,id));
    while(p!=anc) tmp[++cnt2]=col[p],p=f[p][0];
    for(int i=cnt2;i;--i) str[++cnt1]=tmp[i];
    for(int i=1;i<=len;++i) nxt[i]=0;
    for(int i=2;i<=len;++i)
    {
        while(pos&&s[pos+1]!=s[i]) pos=nxt[pos];
        nxt[i]=(pos+=s[pos+1]==s[i]);
    }
    pos=0;
    for(int i=1;i<=cnt1;++i)
    {
        while(pos&&s[pos+1]!=str[i]) pos=nxt[pos];
        pos+=s[pos+1]==str[i];
        if(pos==len) ans[id]++,pos=nxt[pos];
    }
}
void dfs_ans(int x,int p1,int p2)
{
    A.update(p1,1),B.update(p2,1);
    for(int i=0;i<q1[x].size();++i)
        ans[q1[x][i].id]+=A.query(q1[x][i].p)*q1[x][i].v;
    for(int i=0;i<q2[x].size();++i)
        ans[q2[x][i].id]+=B.query(q2[x][i].p)*q2[x][i].v;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,v=e[i].v-'a';
        if(y==f[x][0]) continue;
        dfs_ans(y,A.ch[p1][v],B.ch[p2][v]);
    }
    A.update(p1,-1),B.update(p2,-1);
}
int main()
{
    read(n),read(m);
    for(int i=1;i<n;++i)
    {
        int x,y;
        read(x),read(y),scanf("%s",s),add(x,y,s[0]),add(y,x,s[0]);
    }
    dfs_pre(1,0);
    for(int i=1;i<=m;++i)
    {
        int x,y;
        read(x),read(y),scanf("%s",s+1);
        if(x!=y) work(x,y,i);
    }
    A.build(),B.build(),dfs_ans(1,root,root);
    for(int i=1;i<=m;++i) printf("%d
",ans[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/lhm-/p/14220945.html