Hackerrank

https://www.hackerrank.com/challenges/the-grid-search/forum

今天碰见这题,看见难度是Moderate,觉得应该能半小时内搞定。

读完题目发现是纯粹的一道子矩阵匹配问题,想想自己以前没做过,肯定能学到新算法,于是就开搞了。

于是上网搜到了Rabin-Karp算法,一种基于hashing的模式匹配算法。尽管连一维的我也没写过,但看了思想以后觉得推广到二维应该也不会很难。

于是有了以下代码,原理就是计算子矩阵的hash key。以hash key的比较代替了子矩阵的比较,这样可以首先排除掉hash key不相等的子矩阵。

对于hash key相等的,再用朴素方法判断子矩阵是否相等。

为什么最后还是要判断子矩阵是否相等呢?因为hash key可能存在碰撞,即使概率不大,为了保证正确性也需要进行检查。

学习Rabin-Karp算法的资料在此:

http://blog.sina.com.cn/s/blog_6a09b5a70100nhnr.html

思路虽简单,代码写起来却各种bug,最终我花了不下两个钟头才搞定。Hackerrank果然是给hacker玩的,我这水平在上面真是举步维艰。

不过,这么搞下来,倒是有了实实在在的收获,如果学习算法能一直保持这种节奏就好了。

下面是AC的代码,时空复杂度均为O(N ^ 2):

 1 # 2D Rabin-Karp Algorithm
 2 import re
 3 
 4 MOD = 10 ** 9 + 7
 5 
 6 def get2DMatrix(n, m):
 7     a = [[0 for j in xrange(m)] for i in xrange(n)]
 8     return a
 9     
10 def calcHash(a, nn, mm):
11     n = len(a)
12     m = len(a[0])
13     
14     b = 1
15     for i in xrange(mm):
16         b = b * 10 % MOD
17     b2 = 1
18     for i in xrange(nn):
19         b2 = b2 * b % MOD
20     
21     h = get2DMatrix(n, m)
22     for i in xrange(n):
23         val = 0
24         for j in xrange(m):
25             val = (val * 10 + a[i][j]) % MOD
26             if j >= mm:
27                 val = (val + a[i][j - mm] * (MOD - b)) % MOD
28             h[i][j] = val
29             
30     h2 = get2DMatrix(n, m)
31     h2[0] = h[0][:]
32     for i in xrange(1, n):
33         for j in xrange(m):
34             h2[i][j] = (h2[i - 1][j] * b + h[i][j]) % MOD
35             if i >= nn:
36                 h2[i][j] = (h2[i][j] + h[i - nn][j] * (MOD - b2)) % MOD
37     return h, h2
38 
39 def equal(a, p, ai, aj):
40     np = len(p)
41     mp = len(p[0])
42     for i in xrange(np):
43         for j in xrange(mp):
44             if a[ai + i][aj + j] != p[i][j]:
45                 return False
46     return True
47     
48 def solve():
49     na, ma = map(int, re.split('s+', raw_input().strip()))
50     a = []
51     for i in xrange(na):
52         a.append(map(int, list(raw_input().strip())))
53     np, mp = map(int, re.split('s+', raw_input().strip()))
54     p = []
55     for i in xrange(np):
56         p.append(map(int, list(raw_input().strip())))
57     ha, h2a = calcHash(a, np, mp)
58     hp, h2p = calcHash(p, np, mp)
59     
60     for i in xrange(np - 1, na):
61         for j in xrange(mp - 1, ma):
62             if h2a[i][j] != h2p[np - 1][mp - 1]:
63                 continue
64             if equal(a, p, i - np + 1, j - mp + 1):
65                 print('YES')
66                 return
67     print('NO')
68     
69 if __name__ == '__main__':
70     t = int(raw_input())
71     for ti in xrange(t):
72         solve()
73         
原文地址:https://www.cnblogs.com/zhuli19901106/p/4687979.html