「字符串算法」第3章 KMP 算法课堂过关

「字符串算法」第3章 KMP 算法课堂过关

关于KMP

upd on 2021/4/1 优化一些细节

声明:本文的字符串下标均从1开始,对于某个字符串a,a.substr(i,j)表示a从第i位开始,长度为j的字串

模板题

传送门

KMP算法的大致原理

个人认为其他博客已经讲得很好,这里简单讲,把重点放在next数组上

先推几篇博客:

首先,我们把模板题中的(s_1)串称为文本串,重命名为(s)(s_2)称为模式串,重命名为(t)(本文中不区分s与t的大小写)

(n)(s)的长度,(m)(t)的长度(会在代码片中出现)

看图,在第一轮匹配中,匹配到了一个不相等的位置,如果用暴力,那就是从头再匹配,但是可以看到(t)串中有一段重复的“ABC”,无需重复匹配,所以第二轮直接跳到如图所示的位置比较两个蓝色的部分

这就是KMP算法的大致思路

next数组

定义

看了KMP的大致原理,相信大家都产生了疑问:我怎么知道要让T串跳到哪个位置呢?这就要用到next数组了,这是KMP的核心,也是难点

先不用管怎么求next数组,看定义(我自己写的):

(j=next_i),则有(j<i)(t.substr(1,j)==t.substr(i-j+1,i)),且对于任意(k(j<k<i)),(t.substr(1,k)≠t.substr(i-k+1,k))

也就是说,next[i]表示“T中以i结尾的非前缀字串”与“T的前缀”能匹配的最长长度,当不存在这样的j时,next[i]=0

举个例子:

若T="ABCDABCE",则对应的next={0 0 0 0 1 2 3 0}

应用

根据next数组的定义,next中存储的是长度,但是由于它是T的某个前缀字串的长度,我们也可以将next当做下标使用(一定要弄清楚,不然后面很蒙)

仍然用上面的图片真懒呐

设S的指针为i,T的指针为j,表示当前完成匹配的位置(也就是说S[i]和T[j]是相等的)

第一轮匹配中,当(j==7)时,我们发现(t)的下一位和(s)的下一位不等,但是(t)的第57位和13位是一样的,即next[7]=3,所以我们需要将(t)的指针(j)跳到第3位,也就是j=next[j],这里有一些细节不是很好理解,KMP在实现时是很巧妙的,我们放到整段代码理解:

		while(j != 0 && s[i] != t[j+1])
			j = next[j];
		if(s[i] == t[j+1])
			j++;
		if(j == m){//j==m标志着已经全部完成匹配
			printf("%d
",i - m + 1);
			j = next[j];
		}

求法

这里是整个KMP最难理解的部分,所以放到最后

先贴出代码

	next[1] = 0;//初始化
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])
			j=next[j];//全算法最confusing的语句
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}

考虑暴力枚举:最外层循环枚举每一位(i),第二层枚举next[i],里层判断第二层枚举的是否合法

显然,时间复杂度是在(O(n^2)~O(n^3)),还不如(O(ncdot m))的暴力匹配

优化求法:

先提前声明:求next[i]是要用到next[1~i-1]的,所以我们要从前向后顺序枚举i

定义“候选项”的概念(可能跟《算法竞赛……》的不大一样):如果j满足 t.substr(1,j)==t.substr(i-1-j-1,j)&&j<i-1则j是next[i]的一个候选项

例子:

绿色表示相等的两个字串,则j是next[i]的一个候选项,若标成蓝色的两个字符相等,则候选项j是合法的,next[i]就是所有合法的(j)中的最大值+1

很显然,对于next[i]而言,next[i-1]是它的候选项,但是,问题是next[next[i-1]],next[next[next[i-1]]],......都是候选项,为什么呢?还是看图:

假设next[13]=5,根据(next)的定义,标绿色部分是相等的,再细化一下绿色部分中相等的部分:假设next[5]=2,同理,第二行(不计最上面的下标行)的黄色部分相等,又因为绿色部分相等,我们可以得到第三行的黄色部分都是相等的,再简化为第4行,会发现:这不是和第一行一样了吗(只是长度小了)!

以此类推,可以得到next[i-1],next[next[i-1]],next[next[next[i-1]]],......都是候选项,且他们的值是从左向右递减的,因此,按照这个顺序找到第一个合法的候选值之后,我们就可以确定next[i]

重新看一下代码:

	next[1] = 0;
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])//找到第一个合法的候选项
			j=next[j];//缩小长度
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}

发现,每一轮循环没有j=next[i-1]的语句。原因很简单:上一轮结束时语句next[i]=j决定了这一轮刚开始就有j==next[i-1],注意这里的前后的(i)不一样(都不是同一轮循环了)不要学傻了

时间复杂度

上结论:(O(n+m))

(next)数组的求值为例:

	next[1] = 0;
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])
			j=next[j];
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}

最外层显然是(O(m))的,问题是里面

while循环中,(j)是递减的,但是又不会变成负数,所以整个过程中,(j)的减小幅度不会超过(j)增加的幅度,而(j)每次才增加1,最多增加(m)次,故(j)的总变化次数不超过(2m),整个时间复杂度近似认为是(O(m))

如果还不能理解,就想像一个平面直角坐标系,(x)轴为(i)(y)轴为(j),从原点出发,(i)每向右一个单位,(j)最多向上一个单位,(j)也可以往下掉(while循环),但不能掉到第四象限,(j)向下掉的高度之和就是while内语句执行的总次数,是绝对不会超过(m)

匹配的循环与上述相近,时间为(O(n+m)),不再赘述

所以,总的时间复杂度为(O(n+m))

模板题代码

不要问模板题输出的最后一行是什么意思,我也不知道,反正输出(next)数组就对了

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
using namespace std;
int sread(char s[]) {
	int siz = 1;
	do
		s[siz] = getchar();
	while(s[siz] < 'A' || s[siz] > 'Z');
	while(s[siz] >= 'A' && s[siz] <= 'Z') {
		++siz;
		s[siz] = getchar();
	}
	--siz;
	return siz;
}
char s[nn];
char t[nn];
int next[nn];
int n , m;
int main() {
	n = sread(s);
	m = sread(t);
	next[1] = 0;
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])
			j=next[j];
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}
	for(int i = 1 , j = 0 ; i <=n ; i++){
		while(j != 0 && s[i] != t[j+1])
			j = next[j];
		if(s[i] == t[j+1])
			j++;
		if(j == m){
			printf("%d
",i - m + 1);
			j = next[j];
		}
	}
	for(int i = 1 ; i <= m ; i++)
		printf("%d " , next[i]);
	return 0;
}

A. 【例题1】子串查找

题目

代码

#include <iostream>
#include <cstdio>
#define nn 1000010
using namespace std;
int ssiz , tsiz;
int sread(char *s) {
	int siz = 0;
	char c = getchar();
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')))
		c = getchar();
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
		s[++siz] = c , c = getchar();
	return siz;
}
char s[nn] , t[nn];
int nxt[nn];
int main() {
	ssiz = sread(s);
	tsiz = sread(t);
	
	nxt[1] = 0;
	for(int i = 2 ; i <= tsiz ; i++) {
		int j = nxt[i - 1];
		while(j != 0 && t[j + 1] != t[i])
			j = nxt[j];
		nxt[i] = j + 1;
	}
	
	int ans = 0;
	for(int i = 1 , j = 0 ; i <= ssiz ; i++) {
		while(j != 0 && s[i] != t[j + 1])
			j = nxt[j];
		if(s[i] == t[j + 1])
			++j;
		if(j == tsiz)
			j = nxt[j] , ++ans;
	}
	cout << ans;
	return 0;
}

B. 【例题2】重复子串

题目

题目有误

输入若干行,每行有一个字符串,字符串仅含英文字母。特别的,字符串可能为.即一个半角句号,此时输入结束。

第五组数据的字符串包含数字字符,有图为证:

思路

设字符串长度为(siz)

Hash

关于字符串Hash

不难想也很好写的一种方法

直接枚举最小周期长度(i),显然,(siz)一定是(i)的倍数,所以,这只需要(O(sqrt n))的时间复杂度

假设我们已经枚举到(p)的因数(x),就可以直接用(O(frac{siz}{x}))的时间复杂度验证该子字符串是否是周期,代码如下:

inline bool check(int x) {
	ul key = hs[x];
	for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
		if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)//获取字符串s从下标i开始,长度为x的子串的Hash值 , 判断和key是否相等
			return false;
	return true;
}

KMP

下面讲好些不太好想的KMP做法

先上结论:
命名输入进来的字符串为(S),预处理得到(S)(nxt)数组
(siz\%(siz-nxt_{siz})==0),则(siz-nxt_{siz})(S)的最小周期,也就是说,此时答案为(siz / (siz - nxt_{siz}))
否则,答案为"1"

献上图解:

代码

Hash

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
#define ul unsigned long long
using namespace std;
#define value(_) (_ >= 'A' && _ <= 'Z' ? (1 + _ - 'A') : (_ >= 'a' && _ <= 'z' ? (27 + _ - 'a') : (_ - '0' + 53) ))
const ul p = 131;

ul hs[nn];
ul pw[nn];
int siz;
char c[nn];

int sread(char *s) {
	int siz = 0;
	char c = getchar();
	if(c == '.')return -1;
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
		if((c = getchar()) == '.')	return -1;
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
		s[++siz] = c , c = getchar();
	return siz;
}
inline bool check(int x) {
	ul key = hs[x];
	for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
		if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)
			return false;
	return true;
}
int main() {
	pw[0] = 1;
	for(int i = 1 ; i <= nn - 5 ; i++)
		pw[i] = pw[i - 1] * p;
	while((siz = sread(c)) != -1) {
		for(int i = 1 ; i <= siz ; i++)
			hs[i] = hs[i - 1] * p + value(c[i]);
		
		int ans = 0;
		for(int i = 1 ; i * i <= siz ; i++) {
			if(siz % i == 0)
				if(check(i)) {
					ans = i;
					break;
				}
				else {
					if(check(siz / i))
						ans = siz / i;
				}
		}
		printf("%d
" , siz / ans);
		memset(c, 0  , sizeof(c));
		memset(hs , 0 , sizeof(hs));
	}
	return 0;
}

KMP

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
using namespace std;
int siz;
int sread(char *s) {
	int siz = 0;
	char c = getchar();
	if(c == '.')return -1;
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
		if((c = getchar()) == '.')	return -1;
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
		s[++siz] = c , c = getchar();
	return siz;
}
char s[nn];
int nxt[nn];
int main() {
	while(true) {
		memset(s , 0 , sizeof(s));
		memset(nxt , 0 , sizeof(nxt));
		siz = sread(s);
		if(siz == -1)	break;
		nxt[1] = 0;
		for(int i = 2 , j = 0 ; i <= siz ; i++) {
			while(s[i] != s[j + 1] && j != 0)
				j = nxt[j];
			if(s[i] == s[j + 1])
				++j;
			nxt[i] = j;
		}
		if(siz % (siz - nxt[siz]) == 0)
			printf("%d
" , siz / (siz - nxt[siz]));
		else
			printf("1
");
	}
	return 0;
}

C. 【例题3】周期长度和

题目

传送门

思路&代码

以前写过,传送门

题目

传送门

这题意不是一般人能读懂的,为了读懂题目,我还特意去翻了题解[手动笑哭]

题目大意:

给定一个字符串s

对于(s)的每一个前缀子串(s1),规定一个字符串(Q),(Q)满足:(Q)(s1)的前缀子串且(Q)不等于(s1)(s1)是字符串(Q+Q)的前缀.设(siz)为所有满足条件的(Q)(Q)的最大长度(注意这里仅仅针对(s1)而不是(s),即一个(siz)的值对应一个(s1))

求出所有(siz)的和

不要被这句话误导了:

求给定字符串所有前缀的最大周期长度之和

正确断句:求给定字符串 所有/前缀的最大周期长度/之和

我就想了半天:既然是"最大周期长度",那不是唯一的吗?为什么还要求和呢?

思路

其实这题要AC并不难(看通过率就知道)

看图

要满足(Q)(s1)的前缀,则(Q)(1)~(5)位和(s1)的1~5位是一样的,又因为(s1)(Q+Q)的前缀,所以又要满足(s1)的6~8位和(Q+Q)的6~8位一样,即(s1)的6~8位和Q的1~3位相等,回到(s1),标蓝色的两个位置相等.

回顾下KMP中(next)数组的定义:next[i]表示对于某个字符串a,"a中长度为next[i]的前缀子串"与"a中以第i为结尾,长度为next[i]的非前缀子串"相等,且next[i]取最大值

是不是悟到了什么,是不是感觉这题和(next)数组冥冥之中有某种相似之处?

但是,这仅仅只是开始

按照题目的意思,我们要让(Q)的长度最大,也就是图中蓝色部分长度最小,但是(next)中存的是蓝色部分的最大值,显然,两者相违背,难道我们要改造(next)数组吗?明显不行,若(next)存储的改为最小值,则原来求(next)的方法行不通.考虑换一种思路(一定要对KMP中(next)的求法理解透彻,不然下面看不懂,不行的复习一下),我们知道对于next[i],next[next[i-1]],next[next[next[i]]]...都能满足"前缀等于以(i)结尾的子串"这个条件,且越往后,值越小,所以,我们的目标就定在上面序列中从后往前第一个不为0的(next)

极端条件下,暴力跑可以去到(O(n^2)),理论上会超时(我没试过)

两种优化:

  1. 记忆化,时间效率应该是O(n)这里不详细讲,可以去到洛谷题解查看
  2. 倍增(我第一时间想到并AC的做法):
    我们将j=next[j]这一语句称作"j跳了一次"(感觉怪怪的),将next拓展为2维,next[i][k]表示结尾为i,j跳了2^k的前缀字符长度(也就是next[i][0]等价于原来的next[i])
    借助倍增LCA的思想(没学没关系,现学现用),这里不做赘述,上代码
		int tmp = i;
		for(rr int j = siz[i] ; j >= 0 ; --j)//siz[i]是next[i][j]中第一个为0的小标j,注意倒序枚举
			if(next[tmp][j] != 0)//如果不为0则跳
				tmp = next[tmp][j];

倍增方法在字符串长度去到(10^6)时是非常危险的,带个(log)理论是(2cdot 10^7)左右,常数再大那么一丢丢就TLE了,还好数据比较水,但是作为倍增和KMP的练习做一下也是不错的

最后,记得开longlong(不然我就一次AC了)

完整代码

#include <iostream>
#include <cmath>
#include <cstdio>
#define nn 1000010
#define rr register
#define ll long long
using namespace std;
int next[nn][30] ;
int siz[nn];
char s[nn];
int n;
int main() {
//	freopen("P3435_3.in" , "r" , stdin);
	cin >> n;
	do
		s[1] = getchar();
	while(s[1] < 'a' || s[1] > 'z');
	for(rr int i = 2 ; i <= n ; i++)
		s[i] = getchar();
	
	next[1][0] = 0;
	for(rr int i = 2 , j = 0 ; i <= n ; i++) {
		while(j != 0 && s[i] != s[j + 1])
			j = next[j][0];
		if(s[j + 1] == s[i])
			++j;
		next[i][0] = j;
	}
	
	rr int k = log(n) / log(2) + 1;
	for(rr int j = 1 ; j <= k ; j++)
		for(rr int i = 1 ; i <= n ; i++) {
			next[i][j] = next[next[i][j - 1]][j - 1];
			if(next[i][j] == 0)
				siz[i] = j;
		}
	ll ans = 0;
	for(rr int i = 1 ; i <= n ; ++i) {
		int tmp = i;
		for(rr int j = siz[i] ; j >= 0 ; --j)
			if(next[tmp][j] != 0)
				tmp = next[tmp][j];
		if(2 * (i - tmp) >= i && tmp != i)
			ans += (ll)i - tmp;
	}
	cout << ans;
	return 0;
} 

D. 【例题4】子串拆分

题目

思路

说明,以下思路时间大致复杂度为(O(n^2 )),最坏可以去到(O(n^3)),但数据较水可以通过,看了书,上面的解法也是(O(n^2)),对于(1leq |S|leq 1.5×10^4)来说已经是很极限了

其实思路很简单,我们直接枚举子串的左右边界(L,R),在右边界扩张的同时把新加入的字符的(nxt)求出来.至此,我们得到了子串(c),和(c)(nxt)数组,时间复杂度为(O(n^2))
那么我们如何判断(c)是否符合(c=A+B+C(kle len(A),1le len(B) ))呢?看代码(其实做了B,C题这里很好理解)

			int p = nxt[m];//m为c数组的长度,p即是可能的A的长度
			while(p >= k && p > 0) {
				if(m - p - p >= 1) {
					++ans;
					break;//直接退出,优化
				}
				p = nxt[p];
			}

这个判断的复杂度是可以达到(O(n))的,在数据范围下十分危险

下面看下书中是怎么说的:

我还以为书里有严格(O(n^2))的做法

下面(p_i)(nxt_i)意义相同

考虑没枚举左端点,假设左端点为(l),(A=S[l,|S|]),那么对字符串(A)跑一次KMP,在匹配的过程中,设匹配到第(i)个位置,那么我们就要考虑当前得出的(j),显然(A[1,j]=A[i-j+1,i]).如果(ile 2cdot j),那么令(j=p_j),此时(A[i,j]=A[i-j+1,i]),(j)沿指针(p)不断回跳,直到(2cdot j<i).然后判断(j)是否大于(k),如果是,那么累加答案.

因为每次KMP的复杂度是(O(n)),所以总时间复杂度为(O(n^2))

核心代码

//每次KMP匹配 
inline void solve(char *a) {
	p[1] = 0;
	int n = strlen(a + 1);
	for(int i = 1 , j = 0 ; i < n ; i++) {
		while(j && a[j + 1] != a[i + 1])
			j = p[j];
		if(a[j + 1] == a[i + 1])
			++j;
		p[i + 1] = j;
	}
	for(int i = 1 , j = 0 ; i < n ; i++) {
		while(j && a[j + 1] != a[i + 1])
			j = p[j];
		if(a[j + 1] == a[i + 1])
			j++;
		while(j * 2 >= i + 1)
			j = p[j];
		if(j >= k)
			++ans;
	}
}
//枚举左端点 
int len = strlen(str + 1) - (k << 1);
for(int i = 0 ; i < len ; i++)
	solve(str + i);

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 15000
using namespace std;
int sread(char *s) {
	int siz = 0;
	char c = getchar();
	if(c == '.')return -1;
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
		if((c = getchar()) == '.')	return -1;
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
		s[++siz] = c , c = getchar();
	return siz;
}
int n , k , m;
int ans;
int nxt[nn];
char c[nn] , s[nn];

int main() {
	n = sread(s);
	cin >> k;
	for(int L = 1 ; L <= n ; L++) {
		memset(nxt , 0 , sizeof(nxt));
		memset(c , 0 , sizeof(c));
		for(int i = L ; i <= L + k + k ; i++)
			c[i - L + 1] = s[i];
		m = k + k;
		
		nxt[1] = 0;
		for(int i = 2 ; i <= m ; i++) {
			int j = nxt[i - 1];
			while(c[j + 1] != c[i] && j != 0)
				j = nxt[j];
			if(c[j + 1] == c[i])	++j;
			nxt[i] = j;
		}
		
		for(int R = L + k + k; R <= n ; R++) {
			m = R - L + 1;
			c[m] = s[R];
			
			int j = nxt[m - 1];
			while(c[j + 1] != c[m] && j != 0)
				j = nxt[j];
			if(c[j + 1] == c[m])	++j;
			if(m != 1)	nxt[m] = j;
					
			int p = nxt[m];
			while(p >= k && p > 0) {
				if(m - p - p >= 1) {
					++ans;
					break;
				}
				p = nxt[p];
			}
			
		}
	}
	cout << ans;
	return 0;
}
原文地址:https://www.cnblogs.com/dream1024/p/14612598.html