[HAOI2016]找相同字符【GSAM广义后缀自动机】

题目链接

  对于两个字符串,我们想知道他们有多少不同的公共子串,不妨可以考虑成对于戴尔个串的每个不同的后缀,有多少个相同后缀子串。

  于是可以考虑成为,我们对于第一个串先建立一个后缀树(link树),然后对于第二个串,我们在第一个后缀树上跑,来求答案,但是第二个串要怎么跑呢?我们不妨将第二个串也插入到后缀树上去。

  这样以来,两个串都在后缀树上了,给第一个串设立成为“存在点”,将第二个串设置成为“查询点”。然后我们只需要在后缀树上dfs跑一个贡献就可以了,由于后缀树实际上是路径压缩的,所以不同的点的个数,实际上还要乘以它和它的父亲节点(link节点)的长度差值。

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cmath>
  4 #include <string>
  5 #include <cstring>
  6 #include <algorithm>
  7 #include <limits>
  8 #include <vector>
  9 #include <stack>
 10 #include <queue>
 11 #include <set>
 12 #include <map>
 13 #include <bitset>
 14 #include <unordered_map>
 15 #include <unordered_set>
 16 #define lowbit(x) ( x&(-x) )
 17 #define pi 3.141592653589793
 18 #define e 2.718281828459045
 19 #define INF 0x3f3f3f3f
 20 #define HalF (l + r)>>1
 21 #define lsn rt<<1
 22 #define rsn rt<<1|1
 23 #define Lson lsn, l, mid
 24 #define Rson rsn, mid+1, r
 25 #define QL Lson, ql, qr
 26 #define QR Rson, ql, qr
 27 #define myself rt, l, r
 28 #define pii pair<int, int>
 29 #define MP(a, b) make_pair(a, b)
 30 using namespace std;
 31 typedef unsigned long long ull;
 32 typedef unsigned int uit;
 33 typedef long long ll;
 34 const int maxN = 4e5 + 7;
 35 const int maxP = maxN << 1;
 36 ll ans;
 37 struct SAM
 38 {
 39     struct state
 40     {
 41         int len, link, next[26];
 42     } st[maxP];
 43     int siz = 1, last;
 44     int dp[maxP] = {0}, query[maxP] = {0};
 45     void init()
 46     {
 47         siz = last = 1;
 48         st[1].len = 0;
 49         st[1].link = 0;
 50         memset(st[1].next, 0, sizeof(st[1].next));
 51         siz++;
 52     }
 53     int extend(int c, int val, int question)
 54     {
 55         if(st[last].next[c] && st[last].len + 1 == st[st[last].next[c]].len)
 56         {
 57             last = st[last].next[c];
 58             dp[last] += val;
 59             query[last] += question;
 60             return last;
 61         }
 62         int cur = siz++;
 63         st[cur].len = st[last].len + 1;
 64         dp[cur] += val;
 65         query[cur] += question;
 66         int p = last;
 67         while (p && !st[p].next[c])
 68         {
 69             st[p].next[c] = cur;
 70             p = st[p].link;
 71         }
 72         if (p == 0)
 73         {
 74             st[cur].link = 1;
 75         }
 76         else
 77         {
 78             int q = st[p].next[c];
 79             if (st[p].len + 1 == st[q].len)
 80             {
 81                 st[cur].link = q;
 82             }
 83             else
 84             {
 85                 int clone;
 86                 if(p == last)
 87                 {
 88                     clone = cur;
 89                 }
 90                 else
 91                 {
 92                     clone = siz++;
 93                     st[cur].link = clone;
 94                 }
 95                 st[clone] = st[q];
 96                 st[q].link = clone;
 97                 st[clone].len = st[p].len + 1;
 98                 while (p != 0 && st[p].next[c] == q)
 99                 {
100                     st[p].next[c] = clone;
101                     p = st[p].link;
102                 }
103             }
104         }
105         return last = cur;
106     }
107     vector<int> to[maxP];
108     void bfs()
109     {
110         for(int i=2; i<siz; i++) to[st[i].link].push_back(i);
111     }
112     void dfs(int u)
113     {
114         for(int v : to[u])
115         {
116             dfs(v);
117             dp[u] += dp[v];
118             query[u] += query[v];
119         }
120         if(u ^ 1) ans += 1LL * query[u] * dp[u] * (st[u].len - st[st[u].link].len);
121     }
122 } sam;
123 int main()
124 {
125     int n1, n2;
126     char s[maxN];
127     scanf("%s", s);
128     n1 = (int)strlen(s);
129     sam.init();
130     for(int i=0; i<n1; i++)
131     {
132         sam.extend(s[i] - 'a', 1, 0);
133     }
134     scanf("%s", s);
135     n2 = (int)strlen(s);
136     sam.last = 1;
137     for(int i=0; i<n2; i++)
138     {
139         sam.extend(s[i] - 'a', 0, 1);
140     }
141     sam.bfs();
142     ans = 0;
143     sam.dfs(1);
144     printf("%lld
", ans);
145     return 0;
146 }
147 /*
148 ababaa
149 aba
150 ans:16
151 */
原文地址:https://www.cnblogs.com/WuliWuliiii/p/13710122.html