题解 洛谷 P4218 【[CTSC2010]珠宝商】

首先对特征字符串建 (SAM),来实现对子串的匹配。

有一个 (O(n^2)) 的暴力,分别以每个点为根进行 (dfs),遍历树时记录当前字符串在 (SAM) 上匹配到的节点即可。

考虑用点分治来解决本题这样的树上路径统计问题。对于当前的分治重心 (x),统计连通块中经过 (x) 的路径的贡献。经过 (x) 的路径拆分为 (x) 子树内一个点到 (x) 的路径和 (x)(x) 子树内一个点的路径。因为 (SAM) 可以用 (endpos) 统计以 (x) 所对应的字符结束的子串,所以对特征字符串的反串也建出 (SAM),把第二种路径也转化为第一种路径来便于统计。

(x) 出发 (dfs) 统计路径时,到达一个节点后要往当前字符串前端加入该节点所对应的字符,直接用 (SAM) 是无法处理匹配的。发现在 (SAM) 对应的 (Parent) 树上,一个点到其儿子节点,就是在其对应的子串前端加入字符,所以可以处理出 (Parent) 树上每个点加入字符后对应的儿子节点,这样就可以通过 (Parent) 树来实现前端加入字符来匹配了。这里其实就是建出了后缀树。

两条路径对应的字符串要保证是在特征字符串上是相邻的,因此在 (Parent) 树上打标记,统计时遍历整棵 (Parent) 树来使标记下放到儿子,对于特征字符串的每个位置,两种路径的方案相乘来贡献答案。

直接点分治的复杂度是 (O(n log n + nm)) 的,仍然无法接受,考虑结合 (O(n^2)) 的暴力进行根号分治。点分治时,连通块大小 (leqslant sqrt n) 时采取 (O(n^2)) 暴力,大小 (> sqrt n) 时采取统计路径贡献。

对这种方法来进行复杂度分析。连通块大小 (leqslant sqrt n) 的情况复杂度为 (O(n sqrt n))。考虑点分治进行到第 (k) 层时,最大的连通块大小为 (frac{n}{2^k}),连通块个数为 (2^k),若限制连通块大小最小为 (sqrt n),连通块个数的级别就为 (sqrt n) 了,所以连通块大小 (> sqrt n) 的情况复杂度为 (O(m sqrt n))。得总复杂度为 (O((n+m) sqrt n))

注意统计路径贡献时还需容斥,减去两条路径来自同一个子树的情况。容斥时也需要对每个儿子进行根号分治来保证复杂度。

(code:)

#include<bits/stdc++.h>
#define maxn 100010
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int n,m,S,root,tot,goal,rt;
ll ans;
int siz[maxn],ma[maxn];
bool vis[maxn];
char c[maxn],s[maxn];
struct edge
{
    int to,nxt;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to)
{
    e[++edge_cnt]=(edge){to,head[from]};
    head[from]=edge_cnt;
}
struct SAM
{
    int tot=1,root=1,las=1;
    int fa[maxn],ch[maxn][30],son[maxn][30],len[maxn],siz[maxn],pos[maxn],bel[maxn];
    ll tag[maxn];
    char s[maxn];
    vector<int> ve[maxn];
    void insert(int c,int id)
    {
        int p=las,np=las=++tot;
        len[np]=len[p]+1,siz[np]=1,pos[np]=id,bel[id]=np;
        while(p&&!ch[p][c]) ch[p][c]=np,p=fa[p];
        if(!p) fa[np]=root;
        else
        {
            int q=ch[p][c];
            if(len[q]==len[p]+1) fa[np]=q;
            else
            {
                int nq=++tot;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                len[nq]=len[p]+1,fa[nq]=fa[q],fa[q]=fa[np]=nq;
                while(ch[p][c]==q) ch[p][c]=nq,p=fa[p];
            }
        }
    }
    void dfs(int x)
    {
        for(int i=0;i<ve[x].size();++i)
        {
            int y=ve[x][i];
            dfs(y),siz[x]+=siz[y];
            pos[x]=pos[y],son[x][s[pos[y]-len[x]]]=y;
        }
    }
    void build()
    {
        for(int i=1;i<=m;++i) insert(s[i],i);
        for(int i=2;i<=tot;++i) ve[fa[i]].push_back(i);
        dfs(root);
    }
    void match(int x,int fa,int p,int lenth)
    {
        if(lenth==len[p]) p=son[p][c[x]];
        else if(s[pos[p]-lenth]!=c[x]) p=0;
        if(!p) return;
        tag[p]++;
        for(int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(vis[y]||y==fa) continue;
            match(y,x,p,lenth+1);
        }
    }
    void update(int x)
    {
        for(int i=0;i<ve[x].size();++i)
            tag[ve[x][i]]+=tag[x],update(ve[x][i]);
    }
    void clear()
    {
        for(int i=1;i<=tot;++i) tag[i]=0;
    }
}A,B;
void dfs_root(int x,int fa)
{
    siz[x]=1,ma[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]||y==fa) continue;
        dfs_root(y,x),siz[x]+=siz[y];
        ma[x]=max(ma[x],siz[y]);
    }
    ma[x]=max(ma[x],tot-siz[x]);
    if(ma[x]<ma[root]) root=x;
}
void dfs_get(int x,int fa,int p,int type)
{
    p=A.ch[p][c[x]];
    if(!p) return;
    ans+=A.siz[p]*type;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]||y==fa) continue;
        dfs_get(y,x,p,type);
    }
}
void dfs_del(int x,int fa,int p)
{
    if(x!=goal) p=A.ch[p][c[x]];
    else
    {
        p=A.ch[p][c[x]],p=A.ch[p][c[rt]];
        if(p) dfs_get(x,0,p,-1);
        return;
    }
    if(!p) return;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]||y==fa) continue;
        dfs_del(y,x,p);
    }
}
void dfs_find(int x,int fa,int type)
{
    if(type) dfs_get(x,0,1,1);
    else dfs_del(x,0,1);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]||y==fa) continue;
        dfs_find(y,x,type);
    }
}
void calc(int x,int fa)
{
    A.clear(),B.clear();
    if(!fa) A.match(x,fa,1,0),B.match(x,fa,1,0);
    else A.match(x,fa,A.ch[1][c[fa]],1),B.match(x,fa,B.ch[1][c[fa]],1);
    A.update(1),B.update(1);
    for(int i=1;i<=m;++i)
    {
        if(!fa) ans+=A.tag[A.bel[i]]*B.tag[B.bel[m-i+1]];
        else ans-=A.tag[A.bel[i]]*B.tag[B.bel[m-i+1]];
    }
}
void solve(int x)
{
    if(tot<=S)
    {
        dfs_find(x,0,1);
        return;
    }
    int now=tot;
    vis[x]=true,calc(x,0);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(vis[y]) continue;
        root=0,tot=siz[y];
        if(siz[y]>siz[x]) tot=now-siz[x];
        if(tot<=S) rt=x,goal=y,dfs_find(y,x,0);
        else calc(y,x);
        dfs_root(y,x),solve(root);
    }
} 
int main()
{
    read(n),read(m),S=sqrt(n);
    for(int i=1;i<n;++i)
    {
        int x,y;
        read(x),read(y);
        add(x,y),add(y,x);
    }
    scanf("%s%s",c+1,s+1);
    for(int i=1;i<=n;++i) c[i]-='a';
    for(int i=1;i<=m;++i) s[i]-='a',A.s[i]=s[i],B.s[m-i+1]=s[i];
    A.build(),B.build(),tot=ma[0]=n,dfs_root(1,0),solve(root);
    printf("%lld",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/lhm-/p/13418355.html