洛谷 4384 [八省联考2018]制胡窜——后缀自动机+线段树合并

题目:https://www.luogu.org/problemnew/show/P4384

这个题解说得很好:https://blog.csdn.net/qq_39972971/article/details/79882067

用线段树维护 right 集合出现位置,以及区间的 ( sum(r_{i+1}-r_i)*r_{i+1} ) 和 ( sum(r_{i+1}-r_i) ) 。为了维护这个,要记录区间里第一个 endpos (记为 fr ) 和最后一个 endpos (记为 sc )。并且没有统计区间里最后一个 endpos 与它后面那个 endpos 的贡献。

如果想在线的话,线段树合并的时候要在 cr 和 pr 都有的时候新建节点。(不过有点不知道空间怎样)

设子串出现了 m 次。在线段树上二分找可以套用 ( sum(r_{i+1}-r_i)*r_{i+1} - l_m*sum(r_{i+1}-r_i) ) 的 i 的区间。需要满足 ( l_i<=r_1 , r_{i+1}>=l_m ) 即 ( r_i<=r_1+d-1 , r_{i+1}>=r_m-d+1 ) 。

  记 ( L = r_1+d-1 , R = r_m-d+1 ) 。

  如果没有左右孩子中的一个,直接递归进有的那个孩子里。否则令 ( p_0 = vl[ls].sc , p_1 = vl[rs].fr ) ,先看是否 ( p_0 <=L , p_1 >= R ) 。如果在找最左端,当( p_0 , p_1 ) 满足或者 ( p_0>L )的时候就递归进左子树,否则递归进右子树。如果在找最右端,当 ( p_0 , p_1 ) 满足的时候递归进右子树,但右子树可能再没有满足的了,所以得到返回值之后如果不满足就还是把 ( p_0 ) 返回上去;当 ( p_0 , p_1 ) 不满足的时候就看如果 ( p_0 > L ) 就进左子树,不然进右子树。

  得到 i 的区间端点(设为 ( q_l , q_r ) )之后,要判断一下是不是真的合法。因为按上面的过程,就算没有合法位置,也会返回一个位置(并且 ( q_l ) 和 ( q_r ) 会是相同的)。判断方法就是看看 ( q_l ) 是否满足 “ 与 (r_1) 有交 ” 并且 “下一个子串与 (r_m) 有交或者 ( q_l ) 就是 (r_m) ”( (r_1) 和 (r_m) 表示出现的第一个子串位置和第 m 个子串位置)。

  涉及到找一个出现位置的下一个位置,在线段树上找一下即可。

接着判断特殊情况。

  第一刀不过子串或者第二刀不过子串都需要且仅需要 “ ( r_1 ) 和 ( r_m ) 有交” 这个条件。当自己找出的 ( q_r ) == ( r_m ) 的时候说明满足(因为 ( q_r ) 要求与 ( r_1 ) 有交)。

  第一刀不过子串的贡献是 ( ( l_1 - 1 ) * ( r_1 - l_m ) ) ;第二刀不过子串的贡献是 ( sumlimits_{i=n-r_1}^{n-l_m-1} ) 。

  (仔细考虑 “刀” 的意义。刀只能切在两个位置的缝隙间;第一刀表示题目中的 i 选在了这个缝隙左边;第二刀表示题目中的 i 选在了这个缝隙右边;所以方案数是这样,并且直接就是 i + 1 < j 的)

  然后考虑 ( l_{i+1}>r_1 ) 的情况。这个发生在 ( q_r ) 上。如果 ( q_r ) == ( r_m ) 就不用管了,否则一会儿查线段树的时候去掉 ( q_r ) 这个位置,现在给答案加上 ( ( r_1 - l_i ) * ( r_{i+1} - l_m ) ) (这里的 i 指的是 ( q_r ) )。

查线段树的时候想要查上 ( q_r ) 作为 i 时候的贡献,需要查的区间是 ( [ q_l , q_r+1 ] ) 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ls Ls[cr]
#define rs Rs[cr]
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int Mn(int a,int b){return a<b?a:b;}
int Mx(int a,int b){return a>b?a:b;}
const int N=2e5+5,K=25,M=N*K;
int n,lst=1,cnt=1,len[N],fa[N],go[N][10],tx[N],c[N];
int rt[N],tot,Ls[M],Rs[M],pre[N][K],ps[N]; char s[N];
struct Node{
  int fr,sc; ll s1,s2;
  Node operator+ (const Node &b)const
  {
    if(!fr)return b; if(!b.fr)return *this;
    Node ret; ret.fr=fr; ret.sc=b.sc;
    ret.s1=s1+b.s1; ret.s2=s2+b.s2;
    ret.s1+=(ll)(b.fr-sc)*b.fr; ret.s2+=b.fr-sc;
    return ret;
  }
}vl[M];
void build(int l,int r,int &cr,int p)
{
  cr=++tot; vl[cr].fr=vl[cr].sc=p;
  if(l==r)return; int mid=l+r>>1;
  if(p<=mid)build(l,mid,ls,p); else build(mid+1,r,rs,p);
}
void ins(int w,int bh)
{
  int p=lst,np=++cnt; lst=np; len[np]=len[p]+1; ps[bh]=np;
  build(1,n,rt[np],bh);
  for(;p&&!go[p][w];p=fa[p])go[p][w]=np;
  if(!p){fa[np]=1;return;}
  int q=go[p][w]; if(len[q]==len[p]+1){fa[np]=q;return;}
  int nq=++cnt; len[nq]=len[p]+1;
  fa[nq]=fa[q]; fa[q]=nq; fa[np]=nq;
  memcpy(go[nq],go[q],sizeof go[q]);
  for(;go[p][w]==q;p=fa[p])go[p][w]=nq;
}
void Rsort()
{
  for(int i=1;i<=cnt;i++)tx[len[i]]++;
  for(int i=1;i<=cnt;i++)tx[i]+=tx[i-1];
  for(int i=1;i<=cnt;i++)c[tx[len[i]]--]=i;
}
int mrg(int l,int r,int cr,int pr)
{
  if(!cr||!pr)return cr|pr;
  int tp=++tot,mid=l+r>>1;
  Ls[tp]=mrg(l,mid,ls,Ls[pr]); Rs[tp]=mrg(mid+1,r,rs,Rs[pr]);
  vl[tp]=vl[Ls[tp]]+vl[Rs[tp]];//not ls,rs
  return tp;
}
int get(int l,int r,int cr,int L,int R,int d,bool fx)//r_1,l_m
{
  if(l==r)return l; int mid=l+r>>1;
  if(!ls)return get(mid+1,r,rs,L,R,d,fx);
  if(!rs)return get(l,mid,ls,L,R,d,fx);
  int p0=vl[ls].sc, p1=vl[rs].fr;//r_i,r_{i+1}
  if(!fx)
    {
      if((p0<=L&&p1>=R)||(p0>L))return get(l,mid,ls,L,R,d,fx);
      return get(mid+1,r,rs,L,R,d,fx);
    }
  if(p0<=L&&p1>=R)
    {
      int p=get(mid+1,r,rs,L,R,d,fx);
      if(p<=L)return p; else return p0;
    }
  if(p0>L)return get(l,mid,ls,L,R,d,fx);
  return get(mid+1,r,rs,L,R,d,fx);
}
int fnd(int l,int r,int cr,int L)
{
  if(l==r)return l; int mid=l+r>>1;
  if(mid<L)return fnd(mid+1,r,rs,L);
  if(ls&&vl[ls].sc>=L)return fnd(l,mid,ls,L);
  return fnd(mid+1,r,rs,L);
}
Node qry(int l,int r,int cr,int L,int R)
{
  if(l>=L&&r<=R)return vl[cr];
  int mid=l+r>>1;
  if(L>mid)return qry(mid+1,r,rs,L,R);
  if(R<=mid)return qry(l,mid,ls,L,R);
  Node ret=qry(l,mid,ls,L,R)+qry(mid+1,r,rs,L,R);
  return ret;
}
ll calc(ll a,ll b){return (a+b)*(b-a+1)/2;}
bool chk(int ql,int qr,int r1,int rm,int d,int cr)
{
  if(ql-d+1>r1)return true; if(ql==rm)return false;
  int p=fnd(1,n,rt[cr],ql+1);
  if(p<rm-d+1)return true; return false;
}
int main()
{
  n=rdn();int Q=rdn();
  scanf("%s",s+1);
  for(int i=1;i<=n;i++) ins(s[i]-'0',i);
  Rsort();
  for(int i=cnt;i>1;i--)
    {
      int cr=c[i];
      rt[fa[cr]]=mrg(1,n,rt[fa[cr]],rt[cr]);
    }
  for(int i=2;i<=cnt;i++)
    {
      int cr=c[i]; pre[cr][0]=fa[cr];
      for(int t=1,d=fa[cr];(d=pre[d][t-1]);t++)pre[cr][t]=d;
    }
  int l,r;
  while(Q--)
    {
      l=rdn();r=rdn(); int cr=ps[r],d=r-l+1;ll ans=0;
      for(int t=17;t>=0;t--)
    if(len[pre[cr][t]]>=d)cr=pre[cr][t];
      int r1=vl[rt[cr]].fr, rm=vl[rt[cr]].sc;
      int ql=get(1,n,rt[cr],r1+d-1,rm-d+1,d,0);
      int qr=get(1,n,rt[cr],r1+d-1,rm-d+1,d,1);
      if(chk(ql,qr,r1,rm,d,cr))
    {printf("%lld
",(ll)(n-1)*(n-2)/2);continue;}
      if(qr==rm)//qr==rm: 1&m==1
    {
      ans=(ll)(r1-d)*(r1-rm+d-1);//(l1-1)*(r1-lm)
      ans+=calc(n-r1,n-rm+d-2);//[n-r1,n-lm-1]
    }
      else
    {
      int p=fnd(1,n,rt[cr],qr+1);
      if(p-d+1>r1)
        ans+=(ll)(r1-qr+d-1)*(p-rm+d-1);//(r1-li)*(r_{i+1}-lm)
      else qr=p;//so can cal val of qr
    }
      Node t=qry(1,n,rt[cr],ql,qr);
      ans+=t.s1-(ll)t.s2*(rm-d+1);//-lm
      ans=(ll)(n-1)*(n-2)/2-ans;
      printf("%lld
",ans);
    }
  return 0;
}
原文地址:https://www.cnblogs.com/Narh/p/10623870.html