【poj3415-长度不小于k的公共子串个数】后缀数组+单调栈

这题曾经用sam打过,现在学sa再来做一遍。

基本思路:计算A所有的后缀和B所有后缀之间的最长公共前缀。

分组之后,假设现在是做B的后缀。前面的串能和当前的B后缀产生的公共前缀必定是从前往后单调递增的,每次与h[i]取min时必定将栈尾一些长的全部取出来,搞成一个短的。

所以就开一个栈,栈里存的是长度,同时存一下它的出现此处cnt。

注意各种细节啊。。

  1 #include<cstdio>
  2 #include<cstdlib>
  3 #include<cstring>
  4 #include<iostream>
  5 using namespace std;
  6 
  7 typedef long long LL;
  8 const int N=2*100010;
  9 int K,sl,cl,sa[N],rk[N],Rs[N],wr[N],y[N],h[N];
 10 LL sk[N],cnt[N];
 11 char s[N],c[N];
 12 
 13 void get_sa(int m)
 14 {
 15     for(int i=1;i<=cl;i++) rk[i]=c[i];
 16     for(int i=1;i<=m;i++) Rs[i]=0;
 17     for(int i=1;i<=cl;i++) Rs[rk[i]]++;
 18     for(int i=1;i<=m;i++) Rs[i]+=Rs[i-1];
 19     for(int i=cl;i>=1;i--) sa[Rs[rk[i]]--]=i;
 20     
 21     int ln=1,p=0;
 22     while(p<cl)
 23     {
 24         int k=0;
 25         for(int i=cl-ln+1;i<=cl;i++) y[++k]=i;
 26         for(int i=1;i<=cl;i++) if(sa[i]>ln) y[++k]=sa[i]-ln;
 27         
 28         for(int i=1;i<=cl;i++) wr[i]=rk[y[i]];
 29         for(int i=1;i<=m;i++) Rs[i]=0;
 30         for(int i=1;i<=cl;i++) Rs[wr[i]]++;
 31         for(int i=1;i<=m;i++) Rs[i]+=Rs[i-1];
 32         for(int i=cl;i>=1;i--) sa[Rs[wr[i]]--]=y[i];
 33         
 34         for(int i=1;i<=cl;i++) wr[i]=rk[i];
 35         for(int i=cl+1;i<=cl+ln;i++) wr[i]=0;
 36         p=1;rk[sa[1]]=1;
 37         for(int i=2;i<=cl;i++)
 38         {
 39             if(wr[sa[i]]!=wr[sa[i-1]] || wr[sa[i]+ln]!=wr[sa[i-1]+ln]) p++;
 40             rk[sa[i]]=p;
 41         }
 42         ln*=2,m=p;
 43     }
 44     sa[0]=0,rk[0]=0;
 45 }
 46 
 47 void get_h()
 48 {
 49     int k=0,j;
 50     for(int i=1;i<=cl;i++) if(rk[i]!=1)
 51     {
 52         j=sa[rk[i]-1];
 53         if(k) k--;
 54         while(c[i+k]==c[j+k] && i+k<=cl && j+k<=cl) k++;
 55         h[rk[i]]=k;
 56     }
 57     h[1]=0;
 58 }
 59 
 60 void init()
 61 {
 62     int i,tl;cl=0;
 63     scanf("%s",s+1);
 64     tl=strlen(s+1);sl=tl;
 65     for(i=1;i<=sl;i++) c[++cl]=s[i];
 66     scanf("%s",s+1);
 67     tl=strlen(s+1);
 68     c[++cl]='#';
 69     for(i=1;i<=sl;i++) c[++cl]=s[i];
 70 }
 71 
 72 bool check(int x,int tmp)
 73 {
 74     if(tmp==0) return (x<=sl) ? 0:1;
 75     else       return (x<=sl) ? 1:0;
 76 }
 77 
 78 LL solve(int tmp)
 79 {
 80     int tot=0;
 81     LL sum=0,ans=0;
 82     memset(sk,0,sizeof(sk));
 83     memset(cnt,0,sizeof(cnt));
 84     for(int i=1;i<=cl;i++)
 85     {
 86         if(h[i]<K)
 87         {
 88             for(int j=1;j<=tot;j++) cnt[j]=0;
 89             tot=0;sum=0;
 90         }
 91         else
 92         {
 93             int tcnt=0,tsum=0;
 94             while(sk[tot] > h[i]-K+1)
 95             {
 96                 tcnt+=cnt[tot];
 97                 tsum+=cnt[tot]*sk[tot];
 98                 sk[tot]=0,cnt[tot]=0;
 99                 tot--;
100             }
101             if(tcnt)
102             {
103                 sk[++tot]=h[i]-K+1;
104                 cnt[tot]=tcnt;
105                 sum=sum-tsum+tcnt*sk[tot];
106             }
107             if(check(sa[i],tmp)) ans+=sum;    
108         }
109         if(!check(sa[i],tmp) && (cl-sa[i]+1>=K))
110         {
111             sk[++tot]=(cl-sa[i]+1)-K+1;
112             cnt[tot]++;
113             sum+=sk[tot];
114         }
115     }
116     return ans;
117 }
118 
119 int main()
120 {
121     freopen("a.in","r",stdin);
122     freopen("me.out","w",stdout);
123     while(1)
124     {
125         scanf("%d",&K);
126         if(!K) return 0;
127         init();
128         get_sa(200);
129         get_h();
130         // for(int i=1;i<=cl;i++) printf("%d ",sa[i]);printf("
");
131         // for(int i=1;i<=cl;i++) printf("%d ",rk[i]);printf("
");
132         // for(int i=1;i<=cl;i++) 
133         // {
134             // for(int j=sa[i];j<=cl;j++) printf("%c",c[j]);printf("
");
135         // }
136         printf("%I64d
",solve(0)+solve(1));
137     }
138     return 0;
139 }
原文地址:https://www.cnblogs.com/KonjakJuruo/p/5917762.html