CF 504E Misha and LCP on Tree——后缀数组+树链剖分

题目:http://codeforces.com/contest/504/problem/E

树链剖分,把重链都接起来,且把每条重链的另一种方向的也都接上,在这个 2*n 的序列上跑后缀数组。

对于询问,把两条链拆成一些重链的片段,然后两个指针枚举每个片段,用后缀数组找片段与片段的 LCP ,直到一次 LCP 的长度比两个片段的长度都小,说明两条链的 LCP 截止于此。

把重链放到序列上其实就是把 dfn 作为序列角标。

不太会实现,就借鉴(抄)了别人的代码。之后要多多回顾。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=3e5+5,M=N<<1,K=25;
int n,hd[N],xnt,to[M],nxt[M],tim,dfn1[N],dfn2[N],siz[N],son[N],dep[N],top[N],fa[N];
int sa[M],rk[M],tp[M],tx[M],ht[M][K],bin[K],lg[M];
char ch[N],s[M];
struct Node{int l,len;}a1[N],a2[N];
int Mn(int a,int b){return a<b?a:b;}
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void dfs(int cr,int f)
{
  siz[cr]=1;dep[cr]=dep[f]+1;fa[cr]=f;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=f)
      {
    dfs(v,cr);siz[cr]+=siz[v];
    if(siz[v]>siz[son[cr]])son[cr]=v;
      }
}
void dfsx(int cr,int fa)
{
  dfn1[cr]=++tim;
  if(son[cr])top[son[cr]]=top[cr],dfsx(son[cr],cr);
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa&&v!=son[cr])
      top[v]=v,dfsx(v,cr);
}
void Rsort(int n,int nm)
{
  for(int i=1;i<=nm;i++)tx[i]=0;
  for(int i=1;i<=n;i++)tx[rk[i]]++;
  for(int i=2;i<=nm;i++)tx[i]+=tx[i-1];
  for(int i=n;i;i--)sa[tx[rk[tp[i]]]--]=tp[i];
}
void get_sa(int n)
{
  int nm=26;
  for(int i=1;i<=n;i++)tp[i]=i,rk[i]=s[i];
  Rsort(n,nm);
  for(int k=1;k<=n;k<<=1)
    {
      int tot=0;
      for(int i=n-k+1;i<=n;i++)tp[++tot]=i;
      for(int i=1;i<=n;i++)
    if(sa[i]>k)tp[++tot]=sa[i]-k;
      Rsort(n,nm);
      swap(rk,tp);nm=1;rk[sa[1]]=1;
      for(int i=2,u,v;i<=n;i++)
    {
      u=sa[i]+k;v=sa[i-1]+k;if(u>n)u=0;if(v>n)v=0;
      rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[u]==tp[v])?nm:++nm;//rk[sa[i]]
    }
      if(nm==n)break;
    }
}
void get_ht(int n)
{
  int k=0,j;
  for(int i=1;i<=n;i++)//index of s[]
    {
      for(j=sa[rk[i]-1],k?k--:0;i+k<=n&&j+k<=n&&s[i+k]==s[j+k];k++);
      ht[rk[i]][0]=k;//rk[i]
    }
  lg[1]=0;for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
  bin[0]=1;for(int i=1;i<=lg[n];i++)bin[i]=bin[i-1]<<1;
  
  for(int j=1;j<=lg[n];j++)
    for(int i=1;i+bin[j]-1<=n;i++)
      ht[i][j]=Mn(ht[i][j-1],ht[i+bin[j-1]][j-1]);//+bin[j-1]!
}
int get_lca(int x,int y)
{
  while(top[x]!=top[y])
    {
      if(dep[top[x]]<dep[top[y]])swap(x,y);
      x=fa[top[x]];
    }
  return dep[x]<dep[y]?x:y;
}
void get_a(int x,int y,int &tot,Node *a)
{
  tot=0;int lca=get_lca(x,y);
  while(dep[top[x]]>=dep[lca])
    a[++tot]=(Node){dfn2[x],dfn2[top[x]]-dfn2[x]+1},x=fa[top[x]];
  if(dep[x]>=dep[lca])a[++tot]=(Node){dfn2[x],dfn2[lca]-dfn2[x]+1};
  int bj=tot;
  while(dep[top[y]]>dep[lca])
    a[++tot]=(Node){dfn1[top[y]],dfn1[y]-dfn1[top[y]]+1},y=fa[top[y]];
  if(dep[y]>dep[lca])a[++tot]=(Node){dfn1[son[lca]],dfn1[y]-dfn1[son[lca]]+1};
  reverse(a+bj+1,a+tot+1);
}
int get_ans(int l,int r)//l,r:index of s[]
{
  if(l==r)return (n<<1)-(l-1);
  l=rk[l]; r=rk[r]; if(l>r)swap(l,r);//rk[]!
  int d=lg[r-l];
  return Mn(ht[l+1][d],ht[r-bin[d]+1][d]);
}
int main()
{
  n=rdn();scanf("%s",ch+1);
  for(int i=1,u,v;i<n;i++)
    {
      u=rdn();v=rdn();add(u,v);add(v,u);
    }
  dfs(1,0);top[1]=1;dfsx(1,0);
  for(int i=1,j=(n<<1)+1;i<=n;i++)dfn2[i]=j-dfn1[i],s[dfn1[i]]=s[dfn2[i]]=ch[i]-'a'+1;
  get_sa(n<<1);get_ht(n<<1);
  int Q=rdn(),a,b,c,d,nm1,nm2;
  while(Q--)
    {
      a=rdn();b=rdn();c=rdn();d=rdn();
      get_a(a,b,nm1,a1);get_a(c,d,nm2,a2);
      int p1=1,p2=1,st1=0,st2=0,ans=0;
      while(p1<=nm1&&p2<=nm2)
    {
      int len=get_ans(a1[p1].l+st1,a2[p2].l+st2);
      int d=Mn(a1[p1].len-st1,a2[p2].len-st2);
      len=Mn(len,d);
      ans+=len;st1+=len;st2+=len;
      if(len<d)break;
      if(st1==a1[p1].len)st1=0,p1++;
      if(st2==a2[p2].len)st2=0,p2++;
    }
      printf("%d
",ans);
    }
  return 0;
}
原文地址:https://www.cnblogs.com/Narh/p/10077898.html