Trie树(字典树)总结篇

作用:Trie 适用于元素取值范围较小(如 0/1、26 个字母)的场景,常用于字符串前缀匹配、异或最大值等相关问题。

原理:前缀树,每个节点有固定的 sigma 个子节点,同一层对应各元素的同一位置(pos)。

实现:

非动态开点

leetcode 1707. 与数组中元素的最大异或值

思路:离线处理。将 nums 和查询都按 limit 升序排序,每次查询前只把小于等于 limit 的元素插入 Trie,再在 Trie 中贪心地求 x 与已插入元素的最大异或值。

class Solution {
public:
    #define maxnode 32*100000+10
    #define sigma 2
    struct Trie {
        int ch[maxnode][sigma];
        int value[maxnode]; //  叶子节点的值
        int cnt[maxnode];   // 经该节点的元素个数
        int sz = 1;

        void init(){
            memset(ch, 0 ,sizeof(ch));
            memset(cnt, 0, sizeof(cnt));
            memset(value, 0, sizeof(value));
        }
    
        void insert(int x){
            int u = 0;
            for(int i = 31; i >= 0; i--){
                int c = (x>>i)&1;
                if(!ch[u][c])  ch[u][c] = sz++;
                u = ch[u][c];
                cnt[u]++;
            }
            value[u] = x;
        }

        // 查询与x异或的最大值
        int query(int x){
            int u = 0, res = 0;
            for(int i = 31; i >= 0; --i){
                int a = (x>>i)&1;
                // cout << a << " " << b << endl;
                if(ch[u][a^1])  u = ch[u][a^1];
                else  u = ch[u][a];
            }
            return value[u]^x;
        }
    }trie;
    struct Node {
        int x, limit, id;
        Node(int x, int limit, int id) : x(x), limit(limit), id(id) {}
        bool operator < (const Node& node) {
            return this->limit < node.limit;
        }
    };

    vector<int> maximizeXor(vector<int>& nums, vector<vector<int>>& queries) {
        sort(nums.begin(), nums.end());
        vector<Node>myqueries;
        for(int i = 0;i < queries.size();i++) {
            myqueries.push_back(Node(queries[i][0], queries[i][1], i));
        }
        sort(myqueries.begin(), myqueries.end());
        trie.init();
        vector<int>res(queries.size());
        int pos = 0;  // nums数组的当前位置
        for(auto query : myqueries) {
            int x = query.x, limit = query.limit, id = query.id;
            // cout << x << " " << limit << " " << id << endl;
            while(pos < nums.size() && nums[pos] <= limit)  trie.insert(nums[pos++]);
            //  cout << "pos: " << pos << endl;
            if(pos == 0)  res[id] = -1;
            else  res[id] = trie.query(x);
           
            
        }
        return res;
    }
};

动态开点

leetcode211. 添加与搜索单词 - 数据结构设计

思路:动态开点,用节点指针。遇到 通配符‘.’ 进行dfs

class WordDictionary {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if an added word ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    struct Trie {
        Node* root;
        void init() {
            root = new Node();
        }

        // Adds str, creating missing nodes along its path.
        void insert(const string& str) {
            Node* p = root;
            for (char mych : str) {
                int c = mych - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->is_end = true;
        }

        // Returns true if str[pos..] matches a word stored below p, where '.' matches any
        // single letter. str is taken by const reference because this recurses once per
        // character (26 ways on '.'); a by-value parameter copied the word at every step.
        bool query(const string& str, int pos, Node* p) const {
            if (pos == (int)str.size()) return p->is_end;
            char mych = str[pos];
            if (mych == '.') {
                for (int i = 0; i < 26; i++) {
                    if (p->son[i] && query(str, pos + 1, p->son[i])) return true;
                }
                return false;
            }
            Node* next = p->son[mych - 'a'];
            return next != nullptr && query(str, pos + 1, next);
        }
    } trie;

    WordDictionary() {
        trie.init();
    }

    void addWord(string word) {
        trie.insert(word);
    }

    bool search(string word) {
        return trie.query(word, 0, trie.root);
    }
};

其他的一些例子:

leetcode 745. 前缀和后缀搜索

思路:插入时用一个技巧:对每个单词 word 的每个后缀 suf(包括空串),把 suf + '#' + word 插入 Trie,并在路径上的每个节点更新权重(取最大下标);查询时直接查找 suffix + '#' + prefix 即可。

class WordFilter {
public:

    // Trie node over 'a'..'z' plus the separator '#', which occupies slot 26.
    struct Node {
        Node* son[27] = {};
        int weight = -1;  // largest word index whose key passes through this node
    };

    struct Trie {
        Node* root;
        void init() {
            root = new Node();
        }

        // Child slot used for a key character.
        static int slot(char ch) {
            return ch == '#' ? 26 : ch - 'a';
        }

        // Inserts key str for word `index`, raising the weight of every node on its path.
        void insert(string str, int index) {
            Node* cur = root;
            for (size_t k = 0; k < str.size(); ++k) {
                Node*& next = cur->son[slot(str[k])];
                if (next == nullptr) next = new Node();
                cur = next;
                cur->weight = max(index, cur->weight);
            }
        }

        // Weight of the node reached by str, or -1 if str is not a path in the trie.
        int query(string str) {
            Node* cur = root;
            for (char ch : str) {
                cur = cur->son[slot(ch)];
                if (cur == nullptr) return -1;
            }
            return cur->weight;
        }
    } trie;

    WordFilter(vector<string>& words) {
        trie.init();
        for (int idx = 0; idx < (int)words.size(); ++idx) {
            const string& w = words[idx];
            // Key "<suffix of w>#<w>" for every suffix length, shortest (empty) first.
            for (size_t len = 0; len <= w.size(); ++len) {
                trie.insert(w.substr(w.size() - len) + '#' + w, idx);
            }
        }
    }

    // Largest index of a word with the given prefix and suffix, or -1 if none.
    int f(string prefix, string suffix) {
        const string key = suffix + '#' + prefix;
        return trie.query(key);
    }

};

/**
 * Your WordFilter object will be instantiated and called as such:
 * WordFilter* obj = new WordFilter(words);
 * int param_1 = obj->f(prefix,suffix);
 */
View Code

leetcode 1032. 字符流

题意:每收到一个字符,判断字符流中最近的若干个字符(即当前字符流的某个后缀)是否恰好组成字典中的某个单词。

思路:将每个单词反转后插入 Trie,查询时从字符流中最新的字符开始倒序匹配。注意:如果用 string 保存字符流并在每次查询时拼接或按值传递,会因反复拷贝而超时,应改用 char 数组保存字符流。

class StreamChecker {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if a (reversed) word ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    // Trie built over reversed words, so a suffix of the stream is matched by walking
    // the stream backwards from its newest character.
    struct Trie {
        Node* root = new Node();

        void insert(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->is_end = true;
        }

        // Returns true if some stored word equals myletter[k..cnt-1] for some k.
        bool query(const char* myletter, int cnt) {
            Node* p = root;
            for (int i = cnt - 1; i >= 0; i--) {
                int c = myletter[i] - 'a';
                if (!p->son[c]) break;
                p = p->son[c];
                if (p->is_end) return true;  // a word ends here; no need to look further back
            }
            return p->is_end;
        }
    } trie;

    // Every character received so far, oldest first. A growable char buffer: the stream
    // length is not capped, and appending is amortized O(1) with no per-query copy.
    vector<char> myletter;

    StreamChecker(vector<string>& words) {
        for (string word : words) {
            reverse(word.begin(), word.end());
            trie.insert(word);
        }
    }

    bool query(char letter) {
        myletter.push_back(letter);
        return trie.query(myletter.data(), (int)myletter.size());
    }
};

/**
 * Your StreamChecker object will be instantiated and called as such:
 * StreamChecker* obj = new StreamChecker(words);
 * bool param_1 = obj->query(letter);
 */
View Code

leetcode 336. 回文对

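思路:考虑 “abccc”+“ba” 和 “ab”+“cdcba” 这两类回文对。对于某个 word(作为拼接的后半部分),有两种情况。一是去掉自身的回文前缀后,剩余部分反转后与某个单词完全匹配(如 “cdcba” 去掉 “cdc” 后,“ba” 反转为 “ab”)。二是 word 整体反转后是某个单词的前缀,且该单词剩下的部分是回文串(如 “ba” 反转为 “ab”,是 “abccc” 的前缀,剩余 “ccc” 为回文)。这要求我们在插入的时候保存一些信息。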

ends 记录恰好在该节点结束的单词下标;suffix 记录经过该节点、且该节点之后剩余部分为回文串(或为空)的单词下标。

class Solution {
public:
    // Trie node over lowercase letters, annotated for palindrome-pair lookups.
    struct Node {
        Node* son[26];
        vector<int> ends;    // indices of words that end exactly at this node
        vector<int> suffix;  // indices of words passing through this node whose remaining
                             // characters are a palindrome (or empty)
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
        }
    };

    struct Trie {
        Node* root = new Node();

        // Inserts str (= words[index]) and records ends/suffix info along its path.
        void insert(const string& str, int index) {
            Node* p = root;
            int n = str.size();
            for (int i = 0; i < n; i++) {
                int c = str[i] - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
                // Remaining part str[i+1..n-1] is empty or a palindrome.
                if (i == n - 1 || isPalindrome(str, i + 1, n - 1)) p->suffix.push_back(index);
            }
            p->ends.push_back(index);
        }

        // Walks the trie with str[n-1], str[n-2], ..., str[pos] (i.e. reverse of str[pos..])
        // and returns the node reached, or nullptr if the path does not exist.
        Node* query(const string& str, int pos) {
            Node* p = root;
            int n = str.size();
            for (int i = n - 1; i >= pos; i--) {
                int c = str[i] - 'a';
                if (!p->son[c]) return nullptr;
                p = p->son[c];
            }
            return p;
        }
    } trie;

    // Returns all pairs [a, b] (a != b) such that words[a] + words[b] is a palindrome.
    vector<vector<int>> palindromePairs(vector<string>& words) {
        vector<vector<int>> res;
        int m = words.size();
        for (int i = 0; i < m; i++) {
            trie.insert(words[i], i);
        }
        for (int i = 0; i < m; i++) {
            const string& word = words[i];
            // Empty word: any palindromic word followed by "" is a palindrome.
            if (word.empty()) {
                for (int j = 0; j < m; j++)
                    if (j != i && isPalindrome(words[j], 0, (int)words[j].size() - 1))
                        res.push_back({j, i});
                continue;
            }
            // Case 1: strip a palindromic prefix word[0..j]; reverse(word[j+1..]) must be a word.
            for (int j = 0; j < (int)word.size(); j++)
                if (isPalindrome(word, 0, j)) {
                    Node* p = trie.query(word, j + 1);
                    if (p)
                        for (int num : p->ends)
                            if (num != i) res.push_back({num, i});
                }
            // Case 2: reverse(word) is a prefix of a word whose remainder is a palindrome.
            Node* p = trie.query(word, 0);
            if (p)
                for (int num : p->suffix)
                    if (num != i) res.push_back({num, i});
        }
        return res;
    }

    // True if word[start..end] (inclusive) is a palindrome; an empty range counts as one.
    static bool isPalindrome(const string& word, int start, int end) {
        for (int i = 0; i <= end - start; i++)
            if (word[start + i] != word[end - i]) return false;
        return true;
    }
};
View Code

leetcode 面试题 17.17. 多次搜索

思路:将原串从每个位置开始的后缀都插入一遍,并在路径上的节点记录起始位置,然后再用普通查找即可。

这题用kmp也挺快吧

class Solution {
public:
    // Trie node; `ends` lists the start offsets of every inserted suffix passing through it.
    struct Node {
        Node* son[26] = {};
        vector<int> ends;
    };

    struct Trie {
        Node* root = new Node();

        // Inserts the suffix str[pos..] and tags every node on its path with pos.
        void insert(string str, int pos) {
            Node* cur = root;
            for (auto it = str.begin() + pos; it != str.end(); ++it) {
                Node*& child = cur->son[*it - 'a'];
                if (!child) child = new Node();
                cur = child;
                cur->ends.push_back(pos);
            }
        }

        // Start offsets of every inserted suffix that begins with str (empty if none).
        vector<int> query(string str) {
            Node* cur = root;
            for (size_t k = 0; k < str.size(); ++k) {
                cur = cur->son[str[k] - 'a'];
                if (cur == nullptr) return {};
            }
            return cur->ends;
        }
    } trie;

    // For each small string, lists the positions in `big` where it occurs.
    vector<vector<int>> multiSearch(string big, vector<string>& smalls) {
        for (int start = 0; start < (int)big.size(); ++start) trie.insert(big, start);
        vector<vector<int>> result;
        result.reserve(smalls.size());
        for (const string& piece : smalls) result.push_back(trie.query(piece));
        return result;
    }
};
View Code

leetcode 472 连接词

思路:给定字符串去匹配,每次找到某个前缀,从当前位置接着递归找,像(["a", "aa", "aaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"]),会有大量重复,需要加上记忆化。

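然而,加上记忆化后仍然 TLE。每个位置明明只计算了一次,且 sum(words[i].length) <= 6 * 10^5,按道理不会超时。(可能的原因:如果代码中遗留调试输出,如 `cout << pos << endl;`,每次调用都会输出并刷新缓冲区,会严重拖慢运行;string 按值传参也会带来额外拷贝。提交前应删除调试输出,并改为 const 引用传参。)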

class Solution {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if a dictionary word ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    struct Trie {
        Node* root = new Node();

        void insert(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->is_end = true;
        }

        // Memo for the word currently being tested, indexed by start position 0..n:
        // -2 = not computed, -1 = str[pos..] cannot be split into dictionary words,
        // k >= 0 = number of pieces in the first split found (not necessarily the maximum).
        vector<int> dp;

        // Prepares the memo for a word of length n.
        void init(size_t n) {
            dp.assign(n + 1, -2);
        }

        // Splits str[pos..] into dictionary words, trying shorter first pieces first.
        // Every call passes p == root, so memoizing by pos alone is valid.
        int query(Node* p, const string& str, int pos) {
            int& ans = dp[pos];  // dp is not resized during recursion, so this stays valid
            if (ans != -2) return ans;
            int n = str.size();
            if (pos >= n) return ans = 0;
            for (int i = pos; i < n; i++) {
                int c = str[i] - 'a';
                if (!p->son[c]) return ans = -1;
                p = p->son[c];
                if (p->is_end) {
                    int tmp = query(root, str, i + 1);
                    if (tmp != -1) return ans = tmp + 1;
                }
            }
            return ans = -1;
        }
    } trie;

    static bool stringCompare(const string& s1, const string& s2) {
        return s1.size() < s2.size();
    }

    // Returns every word that is a concatenation of at least two dictionary words.
    vector<string> findAllConcatenatedWordsInADict(vector<string>& words) {
        // All words are inserted before any query, so this sort only fixes the output order.
        sort(words.begin(), words.end(), stringCompare);
        for (const string& word : words) {
            trie.insert(word);
        }
        vector<string> ans;
        for (const string& word : words) {
            trie.init(word.size());
            // Each word matches itself as one piece, so a real concatenation needs >= 2.
            if (trie.query(trie.root, word, 0) >= 2)
                ans.push_back(word);
        }
        return ans;
    }
};
View Code

leetcode 1638 统计只差一个字符的子串数目

题意:s的一个子串,t的一个子串,两个子串只差一个字符,问这样的子串有多少对。

思路:看懂题意后,超暴力,详情看注释

class Solution {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        int num_end;  // how many times the exact substring ending here was inserted
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            num_end = 0;
        }
    };

    struct Trie {
        Node* root = new Node();

        // Inserts str[start..end] (inclusive) and bumps its occurrence count.
        void insert(const string& str, int start, int end) {
            Node* p = root;
            for (int i = start; i <= end; i++) {
                int c = str[i] - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->num_end++;
        }

        // Counts inserted strings (with multiplicity) that equal str[start..end] everywhere
        // except at index pos, where the letter must differ from str[pos]. Starts at `from`.
        // str is a const reference because this recurses up to 25 times per call.
        int query(Node* from, const string& str, int start, int end, int pos) {
            Node* p = from;
            for (int i = start; i <= end; i++) {
                if (i == pos) {  // wildcard position: branch on every different letter
                    int res = 0;
                    for (int j = 0; j < 26; j++)
                        if (j + 'a' != str[pos] && p->son[j])
                            res += query(p->son[j], str, pos + 1, end, pos);
                    return res;
                }
                int c = str[i] - 'a';
                if (!p->son[c]) return 0;
                p = p->son[c];
            }
            return p->num_end;
        }
    } trie;

    // Number of (substring of s, substring of t) pairs that differ in exactly one character.
    int countSubstrings(string s, string t) {
        int n = s.size(), m = t.size();
        for (int i = 0; i < m; i++)
            for (int j = i; j < m; j++) {
                trie.insert(t, i, j);  // every substring of t
            }
        int ans = 0;
        // Every substring s[i..j] of s, and every position k in it as the single differing letter.
        for (int i = 0; i < n; i++)
            for (int j = i; j < n; j++) {
                for (int k = i; k <= j; k++)
                    ans += trie.query(trie.root, s, i, j, k);
            }
        return ans;
    }
};
View Code

leetcode 212. 单词搜索 II

思路:将所有word建树,从board的每个点出发dfs,如果当前字符串已经不是trie的前缀,停止回溯;否则继续。过程中遇到is_end,则加入答案。

class Solution {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if a word from `words` ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    struct Trie {
        Node* root = new Node();

        void insert(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->is_end = true;
        }
    } trie;

    const int dx[4] = {-1, 0, 1, 0};
    const int dy[4] = {0, 1, 0, -1};
    vector<vector<char>> vis;       // vis[x][y] != 0 while (x, y) is on the current path
    vector<string> ans;             // words found, in discovery order
    unordered_map<string, int> mp;  // words already added to ans

    // Extends the path ending at (x, y). str spells the path and p is the trie node for
    // str, so each step is a single child lookup instead of re-walking str from the root.
    void dfs(int x, int y, Node* p, string& str, vector<vector<char>>& board) {
        // Record a hit but keep searching: a longer word may extend this one.
        if (p->is_end && mp[str] == 0) {
            ans.push_back(str);
            mp[str]++;
        }
        int n = board.size(), m = board[0].size();
        for (int i = 0; i < 4; i++) {
            int xx = x + dx[i], yy = y + dy[i];
            if (xx < 0 || xx >= n || yy < 0 || yy >= m || vis[xx][yy]) continue;
            Node* next = p->son[board[xx][yy] - 'a'];
            if (!next) continue;  // str + board[xx][yy] is not a prefix of any word: prune
            vis[xx][yy] = 1;
            str.push_back(board[xx][yy]);
            dfs(xx, yy, next, str, board);
            str.pop_back();
            vis[xx][yy] = 0;
        }
    }

    // Returns every word from `words` that can be traced on the board through adjacent,
    // non-repeating cells.
    vector<string> findWords(vector<vector<char>>& board, vector<string>& words) {
        if (board.empty() || board[0].empty()) return ans;
        for (const string& word : words) trie.insert(word);
        int n = board.size(), m = board[0].size();
        vis.assign(n, vector<char>(m, 0));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++) {
                Node* start = trie.root->son[board[i][j] - 'a'];
                if (!start) continue;
                string str(1, board[i][j]);
                vis[i][j] = 1;
                dfs(i, j, start, str, board);
                vis[i][j] = 0;
            }
        return ans;
    }
};
View Code

leetcode 648 单词替换
思路:dictionary构建字典树,sentence中的单词逐个查找即可,遇到is_end即可停止查找

class Solution {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if a dictionary root ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    struct Trie {
        Node* root = new Node();

        void insert(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->is_end = true;
        }

        // Returns the shortest dictionary root that is a prefix of str, or str itself.
        string query(const string& str) {
            Node* p = root;
            for (int i = 0; i < (int)str.size(); i++) {
                int c = str[i] - 'a';
                if (!p->son[c]) return str;
                p = p->son[c];
                if (p->is_end) return str.substr(0, i + 1);
            }
            return str;
        }
    } trie;

    // Replaces every word of `sentence` by its shortest dictionary root and rejoins the
    // words with single spaces. Words are streamed one at a time, so there is no limit on
    // the number of words.
    string replaceWords(vector<string>& dictionary, string sentence) {
        for (const string& str : dictionary) trie.insert(str);

        stringstream ss(sentence);
        string word, ans;
        bool first = true;
        while (ss >> word) {
            if (!first) ans += ' ';
            first = false;
            ans += trie.query(word);
        }
        return ans;
    }
};
View Code

leetcode 676 实现一个魔法字典

思路:通配符匹配,类似于 leetcode 1638(统计只差一个字符的子串数目)

class MagicDictionary {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if a dictionary word ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    struct Trie {
        Node* root = new Node();

        void insert(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
            }
            p->is_end = true;
        }

        // True if some stored word equals str[start..end] everywhere except index pos,
        // where the stored word has a different letter. Starts from node `from`.
        // str is a const reference because this recurses up to 25 times per call.
        bool query(Node* from, const string& str, int start, int end, int pos) {
            Node* p = from;
            for (int i = start; i <= end; i++) {
                if (i == pos) {  // wildcard position: try every different letter
                    for (int j = 0; j < 26; j++)
                        if (j + 'a' != str[pos] && p->son[j] &&
                            query(p->son[j], str, pos + 1, end, pos))
                            return true;
                    return false;
                }
                int c = str[i] - 'a';
                if (!p->son[c]) return false;
                p = p->son[c];
            }
            return p->is_end;
        }
    } trie;

    MagicDictionary() {

    }

    void buildDict(vector<string> dictionary) {
        for (const string& str : dictionary) trie.insert(str);
    }

    // True if changing exactly one letter of searchWord yields a dictionary word.
    bool search(string searchWord) {
        int n = searchWord.size();
        for (int i = 0; i < n; i++) {
            if (trie.query(trie.root, searchWord, 0, n - 1, i)) return true;
        }
        return false;
    }
};

/**
 * Your MagicDictionary object will be instantiated and called as such:
 * MagicDictionary* obj = new MagicDictionary();
 * obj->buildDict(dictionary);
 * bool param_2 = obj->search(searchWord);
 */
View Code

leetcode 677 键值映射

思路:记录末尾节点的val 和 中间节点的sum

class MapSum {
public:

    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        int num_end;  // value of the key ending here; 0 when no key ends here
        int sum;      // sum of the values of all keys whose path passes through this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            num_end = 0;
            sum = 0;
        }
    };

    struct Trie {
        Node* root = new Node();

        // Changes the value of str from val1 (0 if str was absent) to val2 by adding the
        // difference to every node on its path.
        void insert(const string& str, int val1, int val2) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
                p->sum += val2 - val1;
            }
            p->num_end = val2;
        }

        // Sum of the values of all keys starting with str (0 if none).
        int query_sum(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) return 0;
                p = p->son[c];
            }
            return p->sum;
        }

        // Current value of key str, or 0 if str is not a key. An absent key acting as 0
        // means no sentinel is needed, so any int value (including -1) is stored correctly.
        int query_val(const string& str) {
            Node* p = root;
            for (char ch : str) {
                int c = ch - 'a';
                if (!p->son[c]) return 0;
                p = p->son[c];
            }
            return p->num_end;
        }
    } trie;

    MapSum() {

    }

    // Sets key to val, overwriting any previous value.
    void insert(string key, int val) {
        trie.insert(key, trie.query_val(key), val);
    }

    int sum(string prefix) {
        return trie.query_sum(prefix);
    }
};

/**
 * Your MapSum object will be instantiated and called as such:
 * MapSum* obj = new MapSum();
 * obj->insert(key,val);
 * int param_2 = obj->sum(prefix);
 */
View Code

leetcode 720 词典中最长的单词

class Solution {
public:
    // Trie node over lowercase letters 'a'..'z'.
    struct Node {
        Node* son[26];
        bool is_end;  // true if a word ends at this node
        Node() {
            for (int i = 0; i < 26; i++) son[i] = nullptr;
            is_end = false;
        }
    };

    struct Trie {
        Node* root = new Node();

        // Inserts str and reports whether every proper non-empty prefix of str had
        // already been inserted as a complete word.
        bool insert(const string& str) {
            Node* p = root;
            bool ret = true;
            int n = str.size();
            for (int i = 0; i < n; i++) {
                int c = str[i] - 'a';
                if (!p->son[c]) p->son[c] = new Node();
                p = p->son[c];
                if (i != n - 1 && !p->is_end) ret = false;
            }
            p->is_end = true;
            return ret;
        }
    } trie;

    // True if candidate b should replace the current best a: b is longer, or equally long
    // and lexicographically smaller. (Renamed from `strcmp`, which shadowed ::strcmp.)
    static bool isBetter(const string& a, const string& b) {
        if (a.size() == b.size()) return a > b;
        return a.size() < b.size();
    }

    // Longest word that can be built one letter at a time from other words; ties go to the
    // lexicographically smallest. Lexicographic sorting guarantees every prefix of a word is
    // inserted before the word itself.
    string longestWord(vector<string>& words) {
        sort(words.begin(), words.end());
        string ans = "";
        for (const string& word : words) {
            if (trie.insert(word) && isBetter(ans, word)) ans = word;  // insert always runs
        }
        return ans;
    }
};
View Code

哈哈哈,除了连接词那题莫名的TLE,leetcode上的Trie被做完了。

个性签名:时间会解决一切
原文地址:https://www.cnblogs.com/lfri/p/14573095.html