KMP 学习笔记

KMP(Knuth-Morris-Pratt)是 OI 中常用的字符串匹配算法之一,它可以有效地利用失配信息来使得匹配全过程中不回溯,从而在线性时间内完成匹配。

原理

我们将待匹配的字符串称为主串,用来匹配的字符串称为模式串(模式串长度小于等于主串长度)。

设模式串 P"orzzorzo",主串 S"orzzorqworzzorzo",使用朴素算法进行匹配时("-" 表示匹配成功,"|" 表示在此字符失配):

orzzorqworzzorzo
------|
orzzorzo

首先,将两串对齐,从左到右匹配 PS 的每一位字符。当失配时,将模式串右移一位:

orzzorqworzzorzo
 |
 orzzorzo

此时发现第一位就失配了,还要继续右移……

更好的策略是当 'q' 失配时,直接对齐模式串开头 "or",不必回溯重新匹配:

orzzorqworzzorzo
    --|
    orzzorzo

而接下来这次失配后,本来需要将模式串与 'r' 对齐,但根据上面的思路将模式串直接与 'q' 对齐即可:

orzzorqworzzorzo
      |
      orzzorzo

利用对称的前后缀

通过上述例子,可以发现如果部分匹配的串有对称的前后缀,则我们可以直接将模式串部分匹配串的前缀与主串部分匹配串的后缀对齐,如:

orzzorqworzzorzo
------|
orzzorzo

例子中的部分匹配串为 "orzzor",有对称的前后缀 "or",则可以直接将部分匹配模式串的前缀 "or" 与部分匹配主串的后缀 "or" 对齐。

再来一个例子,模式串为 "qqwqqa",主串为 "qqwqqwqqa"

qqwqqwqqa
-----|
qqwqqa

此时的部分匹配串为 "qqwqq",它有两个对称的前后缀,分别是 "qq""q",如果以 "q" 对齐,可以得到:

qqwqqwqqa
    -|
    qqwqq

在模式串第二个 'q' 处失配后,继续匹配,最终结果是匹配失败。

然而,如果我们以 "qq" 对齐,则有:

qqwqqwqqa
   ------
   qqwqqa

结果是匹配成功。

(事实上我们也不会先对齐 "q",因为先对齐 "qq" 可以更快匹配完

这个例子告诉我们,当部分匹配串有多个对称前后缀时,需要选择最长的,以保证匹配结果的正确。

定义 nxt 数组,长度等于模式串长度,它的第 i 个成员代表以模式串前 i 个字符作为部分匹配串时,部分匹配串的最长对称前后缀长度(前缀末尾的位置)。

匹配步骤:

  1. 如果当前字符匹配(P[j + 1] == S[i + 1]),则继续匹配下一个字符(i++, j++);

  2. 如果当前字符失配(P[j + 1] != S[i + 1]),则直接将模式串右移到前缀与后缀对齐的位置(j = nxt[j])直到匹配或者找到了 0(移动模式串,使模式串部分匹配串的前缀与主串部分匹配串的后缀对齐)。

推导前缀指针

定义 nxt 数组,长度等于模式串长度,它的第 i 个成员代表以模式串前 i 个字符作为部分匹配串时,部分匹配串的最长对称前后缀长度(前缀末尾的位置)。

由定义可得:

i     | 1 | 2 | 3 | 4 | 5 | 6 | 7
P     | q | q | w | q | q | q | a
nxt   | 0 | 1 | 0 | 1 | 2 | 2 | 0

推导方法与 KMP 类似,用自己前缀子串和自己匹配:

  1. j = nxt[i];

  2. 如果当前字符和最长前缀下一个字符匹配(P[j + 1] == P[i + 1]),则 nxt[i + 1]nxt[i] + 1j + 1

  3. 如果当前字符和最长前缀下一个字符失配P[j + 1] != P[i + 1]),则继续对比第 i + 1 个字符与 nxt[nxt[i]] + 1 个字符,以此类推(即 j = nxt[j]继续对比 P[j + 1]P[i + 1]),一直向前找直到匹配或者找到了 0。

如模式串:agctagcagctagct

加粗的 'a' 与最后一个 't' 不匹配,此时向前找找到 "agctagc" 的最后一个 'c'对称位置的后一个字符,发现是 't',则找到前后的 "agct" 是一个对称的前后缀。

实现

const int MAXN = 1000000;

inline int kmp(char *s, char *p) { 
	// 下标从 1 开始
	int ls = strlen(s + 1), lp = strlen(p + 1);
	static int nxt[MAXN + 5];

	// 预处理 nxt
	nxt[1] = 0;
	for (int i = 2; i <= lp; i++) {
		int j = nxt[i - 1];
		while (j > 0 && p[j + 1] != p[i]) j = nxt[j];

		if (p[j + 1] == p[i]) nxt[i] = j + 1;
		else nxt[i] = 0;
	}

	int res = 0; // 匹配次数
	for (int i = 1, j = 0; i <= ls; i++) {
		while (j > 0 && p[j + 1] != s[i]) j = nxt[j];

		if (p[j + 1] == s[i]) j++;

		if (j == lp) {
			res++;
			j = nxt[j];
			// j = 0 // 如果不允许重叠匹配
		}
	}

	return res;
}

参考资料

(基本照搬x

原文地址:https://www.cnblogs.com/lcfsih/p/14352085.html