KMP学习笔记

这是一个比较难以理解的算法,但是理解之后可以跑的飞快~

感谢Δx大佬耐心的讲解~顺便免费广告:https://www.cnblogs.com/popo-black-cat/

好的讲正题:

首先考虑这样一个问题:已知两个字符串s1,s2,s2是s1的子串,统计s2在s1中出现的次数。s1长度为n,s2长度为m。

如果单纯考虑暴力的话,有这样的算法,枚举s1的每一位,以这一位为起点往下枚举s2匹配,之后再枚举s1的下一位。一个显然的事实是这个算法对s1每一位都匹配了m次,因此总复杂度为O(n * m)。

复杂度显然不行,于是KMP算法就出现了。KMP本质就是对这个暴力算法进行优化,利用每次匹配失败的信息优化匹配,这样就不必从s2的第一位开始枚举,而是从前一位匹配上的位置进行继续匹配,因此时间复杂度为O(n + m)。

优化的关键在于维护一个nxt数组,nxt[i]表示s2在0~i位之间,后缀集合和前缀集合的交集的最大子串长度,也就是那个最大的前缀子串的最后一位。首先解释一下后缀集合和前缀集合,举个例子,对于字符串“ababqwq”来说,它的前缀集合就是{"a","ab","aba","abab","ababq","ababqw"},后缀集合就是{"q","wq","qwq","bqwq",abqwq","babqwq"},注意前后缀均不包括字符串本身。然后我们就可以开始求nxt数组了。

nxt[0]和nxt[1]显然都是0,这个是初始化。然后我们分情况讨论一下,假设到第i位,第i-1位和第nxt[i - 1]位匹配上,问题的关键就在于nxt[i - 1] + 1位和第i位是否匹配,如果匹配那么nxt[i]就是nxt[i - 1] + 1,否则就往下找,继续找nxt[nxt[i - 1]] + 1和i是否匹配,因为要保证前面的始终匹配上。这样不断找下去总能找到nxt[i]的值,这样我们就用O(m)的复杂度线性维护了nxt数组。

之后再匹配s1就好办多了,枚举s1的每一位再枚举s2,如果没有匹配上则直接从nxt数组取出对应值向下继续匹配,我们没有回到s2开始匹配,一直线性处理,因此这时复杂度为O(n)。

所以总复杂度为O(n + m)。

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<cmath>
 4 #include<algorithm>
 5 #include<iostream>
 6 #define B cout << "Break" << endl;
 7 #define N 1000010
 8 using namespace std;
 9 int nxt[N];
10 char s[N],s2[N];
11 int ls,ls2;
12 void pre_nxt()
13 {
14     nxt[1] = 0;
15     for(int i = 2;i <= ls2;i++)
16     {
17         int t = nxt[i - 1];
18         for(;;t = nxt[t])
19         {
20             if(t == 0)
21             {
22                 if(s2[1] == s2[i]) t = 1;
23                 else t = 0;
24                 break;
25             }
26             if(s2[t + 1] == s2[i])
27             {
28                 t++;
29                 break;
30             }
31         }
32         nxt[i] = t;
33     }
34 }
35 int ans,pos[N];
36 void kmp()
37 {
38     for(int i = 1,j = 0;i <= ls;i++)
39     {
40         if(s2[j + 1] != s[i]) j = nxt[j];
41         if(s2[j + 1] == s[i]) j++;
42         if(j == ls2)
43         {
44             j = nxt[j];
45             ans++;
46             pos[ans] = i - ls2 + 1;
47         }
48     }
49 }
50 int main()
51 {
52     scanf("%s
%s",s + 1,s2 + 1);
53     ls = strlen(s + 1);ls2 = strlen(s2 + 1);
54     pre_nxt();
55     kmp();
56     for(int i = 1;i <= ans;i++) printf("%d
",pos[i]);
57     for(int i = 1;i <= ls2;i++) printf("%d ",nxt[i]);
58 }
参考代码
原文地址:https://www.cnblogs.com/lijilai-oi/p/10801640.html