[HIHO1260]String Problem I(trie树)

题目链接:http://hihocoder.com/problemset/problem/1260

给定n个字符串和m次询问。每次询问给出一个字符串，问在它的任意一个位置（包括开头和结尾）恰好插入一个字母后，能够得到这n个字符串中的多少个。

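用trie树存储这n个字符串，然后对每个询问从根节点向下dfs：若子节点的字母与询问串当前字符相同，就匹配并前进一位；否则（且尚未插入过字母时）把该字母当作唯一的插入字母。当询问串被完全匹配且恰好插入了一个字母时，只有当前节点是某个字符串的结尾才计入答案。由于相同字母总是优先匹配，同一个目标串只会沿唯一一条路径被统计，不会重复计数（例如询问"ab"得到"aab"只算一次）。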

 1 #include <algorithm>
 2 #include <iostream>
 3 #include <iomanip>
 4 #include <cstring>
 5 #include <climits>
 6 #include <complex>
 7 #include <fstream>
 8 #include <cassert>
 9 #include <cstdio>
10 #include <bitset>
11 #include <vector>
12 #include <deque>
13 #include <queue>
14 #include <stack>
15 #include <ctime>
16 #include <set>
17 #include <map>
18 #include <cmath>
19 
20 using namespace std;
21 
// One trie node. Each lowercase letter 'a'..'z' has its own child slot, and a
// counter records how many inserted strings end exactly at this node.
struct Node {
    Node *next[26];  // next[c - 'a'] is the child for letter c, or NULL if absent
    int cnt;         // number of inserted strings that terminate at this node

    // Creates an empty node: no children, nothing terminates here yet.
    Node() : cnt(0) {
        int k = 26;
        while (k--) {
            next[k] = NULL;
        }
    }
};
33 void insert(Node *p, char *str) {
34     for(int i = 0; str[i]; i++) {
35         int t = str[i] - 'a';
36         if(p->next[t] == NULL) {
37             p->next[t] = new Node();
38         }
39         p = p->next[t];
40     }
41     p->cnt++;
42 }
43 
44 int len, cnt;
45 
46 void dfs(Node *p, char *str, int cur, int flag) {
47     if(flag > 1) return;
48     if(cur == len && flag == 1) {
49         cnt++;
50         return;
51     }
52     for(int i = 0; i < 26; i++) {
53         if(p->next[i]) {
54             // printf("%c
", 'a'+i);
55             if('a' + i == str[cur]) {
56                 dfs(p->next[i], str, cur+1, flag);
57             }
58             else {
59                 if(flag > 1) continue;
60                 dfs(p->next[i], str, cur, flag+1);
61             }
62         }
63     }
64 }
65 
66 void del(Node *root) {
67     for(int i = 0; i < 26; i++) {
68         if(root->next[i] != NULL) {
69             del(root->next[i]);
70         }
71     }
72     delete root;
73 }
74 
75 const int maxn = 100010;
76 int n, m;
77 char tmp[maxn];
78 
79 int main() {
80     // freopen("in", "r", stdin);
81     while(~scanf("%d %d", &n, &m)) {
82         Node *root = new Node();
83         for(int i = 0; i < n; i++) {
84             scanf("%s", tmp);
85             insert(root, tmp);
86         }
87         for(int i = 0; i < m; i++) {
88             scanf("%s", tmp);
89             len = strlen(tmp);
90             cnt = 0;
91             dfs(root, tmp, 0, 0);
92             printf("%d
", cnt);
93         }
94         del(root);
95     }
96     return 0;
97 }
原文地址:https://www.cnblogs.com/kirai/p/5469016.html