[HAOI2016]找相同字符

题目大意:
  给你两个字符串a和b,要求从a和b中各取出一个相等的子串,问不同的取法有多少种。

思路:
  对于a串建立SAM,然后DP求出每个状态Right集合的大小。
  然后把b串放进去匹配,对于每一个匹配到的结点p,它的每一个Right状态都可以匹配一个长度为tmp-s[s[p].link].len的串,那么将s[p].right*(tmp-s[s[p].link].len)计入答案。
  统计每个状态被匹配的次数,并累加进它Parent中。
  将s[top[i]].right*f[top[i]]*(s[top[i]].len-s[s[top[i]].link].len)统计进答案。
  然后发现样例都过不了,不管什么数据都输出0。发现是匹配的时候的w直接用ch复制,而没有用idx(这样就导致b串一点都匹配不了)。
  再交上去发现还是WA了,考虑两个字符串完全相等的情况,答案可以达到$displaystyle{sum_{i=1}^{n}}n^2$,显然要开long long,然后就过了。

  1 #include<cstdio>
  2 #include<cstring>
  3 const int LEN=200001;
  4 char s[LEN];
  5 class SuffixAutomaton {
  6     private:
  7         static const int SIGMA_SIZE=26;
  8         struct State {
  9             int link,trans[SIGMA_SIZE],len,right,app;
 10         };
 11         State s[LEN<<1];
 12         int root,last;
 13         int len;
 14         int sz,newState(const int l) {
 15             sz++;
 16             s[sz].len=l;
 17             return sz;
 18         }
 19         int idx(const char ch) {
 20             return ch-'a';
 21         }
 22         void extend(const char ch) {
 23             const int w=idx(ch);
 24             int p=last,new_p=newState(s[last].len+1);
 25             s[new_p].right=1;
 26             while(p&&!s[p].trans[w]) {
 27                 s[p].trans[w]=new_p;
 28                 p=s[p].link;
 29             }
 30             if(!p) {
 31                 s[new_p].link=root;
 32             } else {
 33                 int q=s[p].trans[w];
 34                 if(s[q].len==s[p].len+1) {
 35                     s[new_p].link=q;
 36                 } else {
 37                     int new_q=newState(s[p].len+1);
 38                     memcpy(s[new_q].trans,s[q].trans,sizeof s[q].trans);
 39                     s[new_q].link=s[q].link;
 40                     s[q].link=s[new_p].link=new_q;
 41                     while(p&&s[p].trans[w]==q) {
 42                         s[p].trans[w]=new_q;
 43                         p=s[p].link;
 44                     }
 45                 }
 46             }
 47             last=new_p;
 48         }
 49         int cnt[LEN],top[LEN<<1],f[LEN<<1];
 50         void tsort() {
 51             for(int i=1;i<=sz;i++) {
 52                 cnt[s[i].len]++;
 53             }
 54             for(int i=len;i;i--) {
 55                 cnt[i-1]+=cnt[i];
 56             }
 57             for(int i=1;i<=sz;i++) {
 58                 top[cnt[s[i].len]--]=i;
 59             }
 60         }
 61     public:
 62         void build(char str[]) {
 63             len=strlen(str);
 64             root=last=newState(0);
 65             for(int i=0;str[i];i++) {
 66                 extend(str[i]);
 67             }
 68         }
 69         long long query(char str[]) {
 70             tsort();
 71             for(int i=1;i<=sz;i++) {
 72                 if(s[top[i]].link) {
 73                     s[s[top[i]].link].right+=s[top[i]].right;
 74                 }
 75             }
 76             long long ans=0,n=strlen(str),tmp=0;
 77             int p=root;
 78             for(int i=0;i<n;i++) {
 79                 const int w=idx(str[i]);
 80                 if(s[p].trans[w]) {
 81                     p=s[p].trans[w];
 82                     tmp++;
 83                 } else {
 84                     while(p&&!s[p].trans[w]) {
 85                         p=s[p].link;
 86                     }
 87                     if(!p) {
 88                         p=root;
 89                         tmp=0;
 90                     } else {
 91                         tmp=s[p].len+1;
 92                         p=s[p].trans[w];
 93                     }
 94                 }
 95                 s[p].app++,ans+=(long long)s[p].right*(tmp-s[s[p].link].len);
 96             }
 97             for(int i=1;i<=sz;i++) {
 98                 f[s[top[i]].link]+=f[top[i]]+s[top[i]].app;
 99             }
100             for(int i=1;i<sz;i++) {
101                 ans+=(long long)s[top[i]].right*f[top[i]]*(s[top[i]].len-s[s[top[i]].link].len);
102             }
103             return ans;
104         }
105 };
106 SuffixAutomaton sam;
107 int main() {
108     scanf("%s",s);
109     sam.build(s);
110     scanf("%s",s);
111     printf("%lld
",sam.query(s));
112     return 0;
113 }
原文地址:https://www.cnblogs.com/skylee03/p/7528373.html