NOI2018 你的名字

传送门

这道题确实是相当优美的……
以下S串指的是ION2017的串,T串指的是ION2018的串。
题目要求的是,两个字符串本质不同的非公共串的个数。这个转化一下,只要求两个字符串本质不同的公共串的个数即可,因为我们用本质不同的串数减一下就可以了。这个大家都会求。
首先我们考虑l=1,r=|S|的情况。这个的话,我们考虑对于每一个T串,求出它的每一个前缀(t_i)在S串中能被匹配的长度。这个还是比较好做的,我们只要先把S串的SAM建出来,之后对于每一个T串也把它的SAM建出来,额外记录SAM上的每个节点对应原串的位置。首先把T在S串上跑匹配,记录每个位置上的匹配长度(ans[i]),之后再在T串的SAM上跑一遍,每个节点对答案的贡献就是((0,ans[pos[i]])([l[fa[i]],l[i]])的交集包含的整数个数。pos[i]是当前SAM节点对应的原串位置。
先看一下暴力的68分代码。

#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define pr pair<int,int>
#define mp make_pair
#define fi first
#define sc second
using namespace std;
typedef long long ll;
const int mod = 64123;
const int M = 500005;
const int N = 5000005;
 
int read()
{
   int ans = 0,op = 1;char ch = getchar();
   while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
   while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
   return ans * op;
}

char s[M],t[M];
int a[M<<1],c[M<<1],size[M<<1],n,len,q,L,R,m;
int ans[M<<1],lim[M<<1],flim[M<<1];

struct Suffix
{
   int last,cnt,ch[M<<2][26],fa[M<<1],l[M<<1],rpos[M<<1];
   void clear()
   {
      rep(i,1,cnt)
      {
     rep(j,0,25) ch[i][j] = 0;
     fa[i] = l[i] = ans[i] = rpos[i] = 0;
      }
      cnt = last = 1;
   }
   void extend(int pos,int c)
   {
      int p = last,np = ++cnt;
      last = cnt,l[np] = l[p] + 1,rpos[np] = pos;
      while(p && !ch[p][c]) ch[p][c] = np,p = fa[p];
      if(!p) fa[np] = 1;
      else
      {
     int q = ch[p][c];
     if(l[p] + 1 == l[q]) fa[np] = q;
     else
     {
        int nq = ++cnt;
        l[nq] = l[p] + 1,memcpy(ch[nq],ch[q],sizeof(ch[q]));
        fa[nq] = fa[q],fa[q] = fa[np] = nq,rpos[nq] = rpos[q];
        while(ch[p][c] == q) ch[p][c] = nq,p = fa[p];
     }
      }
   }
   void match(char *t)
   {
      int p = 1,cur,la = 0;
      rep(i,1,m)
      {
     int c = t[i] - 'a';
     while(p != 1 && !ch[p][c]) p = fa[p];
     cur = min(l[p],la);
     if(ch[p][c]) cur++,p = ch[p][c];
     ans[i] = la = cur;
      }
   }
   ll run()
   {
      ll now = 0;
      rep(i,2,cnt)
      {
     now += l[i] - l[fa[i]];
     if(ans[rpos[i]] > l[fa[i]]) now -= min(ans[rpos[i]],l[i]) - l[fa[i]];
      }
      return now;
   }
}S,T;
 
int main()
{
   scanf("%s",s+1),len = strlen(s+1),S.clear();
   rep(i,1,len) S.extend(i,s[i] - 'a');
   q = read();
   while(q--)
   {
      scanf("%s",t+1),L = read(),R = read(),m = strlen(t+1);
      T.clear(),S.match(t);
      rep(i,1,m) T.extend(i,t[i] - 'a');
      ll now = T.run();
      printf("%lld
",now);
   }
   return 0;
}

然后我们考虑100分做法。其实大体的思路还没有变,我们依然这样考虑,假设现在已经知道了T串的一个前缀(t_i)在S串中的匹配,那么如何求出(t_{i+1})呢?首先肯定我们要在S串上找到这个转移,如果没有就跳父亲……在后面加字符是和正常匹配一样的,现在问题在于我们要确定它是处于S串的([l,r])区间里的。这样的话一旦字符串的前面不符合的话我们是需要减去前面的限制的。如果我们已经知道对于每个状态对应的endpos集合,那么我们只要在这个状态上查找在([l+len,r])上是否有一个endpos即可,其中len是当前匹配长度。如果有的话就说明当前的串是在S中可以被匹配的。(这个地方可能需要多考虑一下,简而言之就是如果存在一个endpos,就说明当前状态对应了一个结束位置在[l+len,r]的子串,因为这个串长度为len,所以一定是被完整的包含在[l,r]之中的)如果没有的话就减少当前匹配长度len,如果减少到一定长度,就跳到自己的父亲上,这样做下去就可以了。
然后每个状态对应的endpos集合用线段树合并即可。
最后每次我们从当前匹配的点倒着跑回去计算一下答案。注意实现的时候,因为每个点对答案的贡献只算一次,所以我们要实时更改当前的值。这个详细的看代码就可以了。

这题确实是相当精妙啊……

#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define pr pair<int,int>
#define mp make_pair
#define fi first
#define sc second
using namespace std;
typedef long long ll;
const int mod = 64123;
const int M = 500005;
const int N = 5000005;
 
int read()
{
   int ans = 0,op = 1;char ch = getchar();
   while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
   while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
   return ans * op;
}

char s[M],h[M];
int a[M<<1],c[M<<1],n,q,L,R,m,root[M<<1],tot,cur;

struct tree
{
   int lc,rc,val;
}t[M*60];

void modify(int &p,int l,int r,int pos)
{
   t[p = ++cur].val++;
   if(l == r) return;
   int mid = (l+r) >> 1;
   if(pos <= mid) modify(t[p].lc,l,mid,pos);
   else modify(t[p].rc,mid+1,r,pos);
}

int query(int p,int l,int r,int kl,int kr)
{
   if(!p || l > r) return 0;
   if(l == kl && r == kr) return t[p].val;
   int mid = (l+r) >> 1;
   if(kr <= mid) return query(t[p].lc,l,mid,kl,kr);
   else if(kl > mid) return query(t[p].rc,mid+1,r,kl,kr);
   else return query(t[p].lc,l,mid,kl,mid) + query(t[p].rc,mid+1,r,mid+1,kr);
}

int merge(int a,int b)
{
   if(!a || !b) return a | b;
   int p = ++cur;
   t[p].val = t[a].val + t[b].val;
   t[p].lc = merge(t[a].lc,t[b].lc);
   t[p].rc = merge(t[a].rc,t[b].rc);
   return p;
}

struct Suffix
{
   int last,cnt,ch[M<<2][26],fa[M<<1],l[M<<1],rpos[M<<1],f[M<<1];
   void clear()
   {
      rep(i,1,cnt)
      {
     rep(j,0,25) ch[i][j] = 0;
     fa[i] = l[i] = f[i] = rpos[i] = 0;
      }
      cnt = last = 1;
   }
   void extend(int v,int c)
   {
      int p = last,np = ++cnt;
      last = np,l[np] = l[p] + 1;
      while(p && !ch[p][c]) ch[p][c] = np,p = fa[p];
      if(!p) fa[np] = 1;
      else
      {
     int q = ch[p][c];
     if(l[p] + 1 == l[q]) fa[np] = q;
     else
     {
        int nq = ++cnt;
        l[nq] = l[p] + 1,memcpy(ch[nq],ch[q],sizeof(ch[q]));
        fa[nq] = fa[q],fa[q] = fa[np] = nq;
        while(ch[p][c] == q) ch[p][c] = nq,p = fa[p];
     }
      }
      if(v) modify(root[np],1,n,v),rpos[np] = v;
   }
   void cal()
   {
      rep(i,1,cnt) c[l[i]]++;
      rep(i,1,cnt) c[i] += c[i-1];
      rep(i,1,cnt) a[c[l[i]]--] = i;
      per(i,cnt,1) root[fa[a[i]]] = merge(root[fa[a[i]]],root[a[i]]);
   }
}S,T;
 
int main()
{
   scanf("%s",s+1),n = strlen(s+1),S.clear();
   rep(i,1,n) S.extend(i,s[i] - 'a');
   S.cal(),q = read();
   while(q--)
   {
      scanf("%s",h+1),L = read(),R = read(),m = strlen(h+1),tot = 0,T.clear();
      rep(i,1,m) T.extend(0,h[i] - 'a');
      rep(i,1,T.cnt) T.f[i] = T.l[i];
      int u1 = 1,u2 = 1;
      ll now = 0;
      rep(i,1,m)
      {
     int c = h[i] - 'a';
     u2 = T.ch[u2][c];
     if(S.ch[u1][c] && query(root[S.ch[u1][c]],1,n,L+tot,R)) tot++,u1 = S.ch[u1][c];
     else
     {
        while(u1 && ((!S.ch[u1][c]) || (!query(root[S.ch[u1][c]],1,n,L+tot,R))))
        {
           if(!tot) {u1 = 0;break;}
           tot--;
           if(tot == S.l[S.fa[u1]]) u1 = S.fa[u1];
        }
        if(!u1) tot = 0,u1 = 1;
        else tot++,u1 = S.ch[u1][c];
     }
     int x = u2;
     while(x)
     {
        if(tot <= T.l[T.fa[x]]) now += T.f[x] - T.l[T.fa[x]],T.f[x] = T.l[T.fa[x]];
        else {now += T.f[x] - tot,T.f[x] = tot;break;}
        x = T.fa[x];
     }
      }
      printf("%lld
",now);
   }
   return 0;
}

原文地址:https://www.cnblogs.com/captain1/p/10527019.html