可持久化Trie模板

如果你了解过 01 Trie 和 可持久化线段树(例如 :  主席树 )、那么就比较好去可持久化 Trie

可持久化 Trie 当 01 Trie 用的时候能很方便解决一些原本 01 Trie 不能解决的一些问题

01 Trie 的经典贪心算法可以在一个数集里面找出某个数和 X 异或的最值

但若数集不固定、变成了每次问询一段区间或者树上路径此时 01 Trie 便无法快速解决

这个时候需要使用可持久化的 Trie 来维护和进行查询操作、例如用前缀和建 Trie 就能方便查询某一区间的状况

可持久化 Trie 和主席树很类似,都是通过为每个前缀or路径等存储一颗 Trie

然后再通过减法的方式来达到某一区间或者某一历史版本的状态

这里只给出模板、关于这个算法的学习、推荐 ==> Click here

模板 :

#include<bits/stdc++.h>
using namespace std;
const int maxNode = 1e6 + 10;
const int maxn = 1e5 + 10;
int root[maxn];///每颗Trie的根节点编号
int sz[maxNode];///每个节点被添加or访问了多少次
int ch[maxNode][2];///静态指针、指向每个节点的01边指向的节点
int totNode = 0;///用于新开节点编号、多测问题别忘初始化

///静态开辟节点
int newNode()
{
    memset(ch[totNode], 0, sizeof(ch[totNode]));
    sz[totNode] = 0;
    return totNode++;
}

///F是将要被继承的树、C是当前新增的树、val就是即将被添加到C这棵树上的值
inline void Insert(int F, int C, int val)
{
    F = root[F], C = root[C];
    for(int i=15; i>=0; i--){
        int bit = (val>>i) & 1;
        if(!ch[C][bit]){
            ch[C][bit] = newNode();
            ch[C][!bit] = ch[F][!bit];
            sz[ ch[C][bit] ] = sz[ ch[F][bit] ];
        }
        C = ch[C][bit], F = ch[F][bit];
        sz[C]++;
    }
}

///查询函数可以说是很多变了
///可持久化Trie的查询并不是很模板的一个东西
///所以请务必理解可持久化Trie的原理再来运用这个模板
///以下的查询函数是 HDU 4757 的查询函数
int Query(int A, int B, int val)
{
    int lca = LCA(A, B);
    int lcaAns = arr[lca]^val;
    A = root[A], B = root[B], lca = root[lca];
    int ret = 0;
    for(int i=15; i>=0; i--){
        int bit = (val>>i) & 1;
        if(sz[ch[A][!bit]] + sz[ch[B][!bit]] - 2 * sz[ch[lca][!bit]] > 0){
            ret += 1<<i;
            A = ch[A][!bit];
            B = ch[B][!bit];
            lca = ch[lca][!bit];
        }else A = ch[A][bit], B = ch[B][bit], lca = ch[lca][bit];
    }

    return max(ret, lcaAns);
}

///这个查询函数则对应 BZOJ 3261
int query(int x, int y, int val)
{
    int ret = 0;
    for(int i=Bit; i>=0; i--){
        int c = (val>>i) & 1;
        if(sum[ch[y][!c]] - sum[ch[x][!c]] > 0)
            ret += (1<<i),
            y = ch[y][!c],
            x = ch[x][!c];
        else x = ch[x][c], y = ch[y][c];
    }
    return ret;
}
View Code

一些题目 :

BZOJ 3261 最大异或和

分析 : 每次更新都是从最后加一个数、而每次问询都是查询某一后缀的异或和

可持久化 01 Trie 能够很方便知道某一区间的状况、但关键是往区间里面装什么才能方便查询

注意到异或的自反性质、可以每次给每一个前缀以可持久化的方式建立一颗 Trie 

然后对于区间 (L, R) 在可持久化 01 Trie 内查询其两个区间与 (PreSum[N] xor x) 的异或最大值便是答案

因为如果某个下标假设为 idx ( L ≤ idx ≤ R ) 且 PreSum[idx] xor (PreSum[N] xor x) 有最大值

那么就说明了 idx 便是这个 p 、因为上面的异或表达式实际上 (1 ~ idx) 这段的异或和都被自反掉了

所以剩下的肯定就是后缀异或和了、所以利用前缀异或和建可持久化 Trie 即可。

#include<bits/stdc++.h>
#define LL long long
#define ULL unsigned long long
 
#define scs(i) scanf("%s", i)
#define sci(i) scanf("%d", &i)
#define scd(i) scanf("%lf", &i)
#define scl(i) scanf("%lld", &i)
#define scIl(i) scanf("%I64d", &i);
#define scii(i, j) scanf("%d %d", &i, &j)
#define scdd(i, j) scanf("%lf %lf", &i, &j)
#define scll(i, j) scanf("%lld %lld", &i, &j)
#define scIll(i, j) scanf("%I64d %I64d", &i, &j)
#define sciii(i, j, k) scanf("%d %d %d", &i, &j, &k)
#define scddd(i, j, k) scanf("%lf %lf %lf", &i, &j, &k)
#define sclll(i, j, k) scanf("%lld %lld %lld", &i, &j, &k)
#define scIlll(i, j, k) scanf("%I64d %I64d %I64d", &i, &j, &k)
 
#define lson l, m, rt<<1
#define rons m+1, r, rt<<1|1
#define lowbit(i) (i & (-i))
#define mem(i, j) memset(i, j, sizeof(i))
 
#define fir first
#define sec second
#define ins(i) insert(i)
#define pb(i) push_back(i)
#define pii pair<int, int>
#define mk(i, j) make_pair(i, j)
#define pll pair<long long, long long>
using namespace std;
const int maxn = (300000<<1) + 30;
const int Bit = 23;
int ch[maxn*24][2], sum[maxn*24], sz = 1;
int root[maxn];
 
int newNode()
{
    memset(ch[sz], 0, sizeof(ch[sz]));
    sum[sz] = 0;
    return sz++;
}
 
void Insert(int y, int x, int val)
{
    for(int i=Bit; i>=0; i--){
        int c = (val>>i) & 1;
        if(!ch[x][c]){
            ch[x][c] = newNode();
            ch[x][!c] = ch[y][!c];
            sum[ch[x][c]] = sum[ch[y][c]];
        }
        x = ch[x][c], y = ch[y][c];
        ++sum[x];
    }
}
 
int query(int x, int y, int val)
{
    int ret = 0;
    for(int i=Bit; i>=0; i--){
        int c = (val>>i) & 1;
        if(sum[ch[y][!c]] - sum[ch[x][!c]] > 0)
            ret += (1<<i),
            y = ch[y][!c],
            x = ch[x][!c];
        else x = ch[x][c], y = ch[y][c];
    }
    return ret;
}
 
int N, M, arr[maxn], PreSum[maxn];
int main(void)
{
    scii(N, M);
    arr[1] = 0, N++;
    for(int i=2; i<=N; i++) sci(arr[i]);
    for(int i=1; i<=N; i++) PreSum[i] = PreSum[i-1]^arr[i];
    for(int i=1; i<=N; i++){
        root[i] = newNode();
        Insert(root[i-1], root[i], PreSum[i]);
    }
 
    char ch[3];
    int l, r, x;
    while(M--){
        scs(ch);
        if(ch[0] == 'A'){
            N++;
            sci(arr[N]);
            PreSum[N] = PreSum[N-1]^arr[N];
            root[N] = newNode();
            Insert(root[N-1], root[N], PreSum[N]);
        }else{
            sciii(l, r, x);
            int ans = query(root[l-1], root[r], PreSum[N]^x);
            printf("%d
", ans);
        }
    }
    return 0;
}
View Code

HDU 4757 Tree

题意 : 给出一颗树、树上的节点都有权值、接下来给出若干个询问 (u、v、x)

问从 u 到 v 最短路径上哪个节点和 x 异或结果最大、输出这个结果

分析 : 和上一题有点类似、只不过这里的区间变成了 u 到 v 间的最短路径即树上路径

嘚想办法建出和上题一样类似 "前缀和" 的东西方便使用减法来查询历史版本

针对树上路径的自然联想到 LCA 如果先随意指定一个根节点、先变成一颗有根树

那么 u 到 v 的最短路径可以用 LCA 表示为 dist(u) + dist(v) - 2 * dist(LCA(u, v))

根据上面这条式子、可以得到一个启发、可持久化 Trie 也可以通过这种方法来得到

我们需要的树上路径 Trie 即可以理解为从 u 到 v 最短路径中所有节点的组成 Trie

那做法就是先指定一个根、然后从根开始DFS、每次遍历到一个新节点

便以可持久化的方式新建一颗 Trie 且是继承自其父亲节点的 Trie

那么在查询的时候、给出两个节点 u、v 就对于每一位就可以通过

sz(u) + sz(v) - 2 * sz(LCA(u, v)) 来判断某一位是否包含 0/1 从而进行贪心选择

最后得到异或最值、当然注意这样做会漏掉 LCA、最后只要和 LCA ^ x 取最大便是答案

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
struct EDGE{ int v, nxt, w; }Edge[maxn<<1];
int Head[maxn], cnt;
int dep[maxn], maxDep;
int Fa[maxn][16];
int arr[maxn];
int n, m;

const int maxNode = 2e6 + 10;
int root[maxn], sz[maxNode], ch[maxNode][2];
int tot = 0;

inline void init_Graph()
{
    memset(Head, -1, sizeof(Head));
    memset(Fa, 0, sizeof(Fa));
    cnt = 0; maxDep = 0;
}

inline void AddEdge(int from, int to)
{
    Edge[cnt].v = to;
    Edge[cnt].nxt = Head[from];
    Head[from] = cnt++;
}

int newNode()
{
    memset(ch[tot], 0, sizeof(ch[tot]));
    sz[tot] = 0;
    return tot++;
}

inline void Insert(int F, int C, int val)
{
    F = root[F], C = root[C];
    for(int i=15; i>=0; i--){
        int bit = (val>>i) & 1;
        if(!ch[C][bit]){
            ch[C][bit] = newNode();
            ch[C][!bit] = ch[F][!bit];
            sz[ ch[C][bit] ] = sz[ ch[F][bit] ];
        }
        C = ch[C][bit], F = ch[F][bit];
        sz[C]++;
    }
}

void dfs(int v, int fa)
{
    root[v] = newNode();
    Insert(fa, v, arr[v]);
    if(Fa[v][0] != 0) maxDep = max(maxDep, dep[v] = dep[Fa[v][0]]+1);
    for(int i=Head[v]; i!=-1; i=Edge[i].nxt){
        int Eiv = Edge[i].v;
        if(Eiv == Fa[v][0]) continue;
        Fa[Eiv][0] = v;
        dfs(Eiv, v);
    }
}

inline void Doubling()
{
    int UP = (int)(log(maxDep)/log(2));
    for(int j=1; j<=UP; j++){
        for(int i=1; i<=n; i++){
            if(Fa[i][j-1] != 0)
                Fa[i][j] = Fa[Fa[i][j-1]][j-1];
        }
    }
}

int LCA(int u, int v)
{
    int UP = (int)(log(maxDep)/log(2));
    if(dep[u] < dep[v]) swap(u, v);
    for(int j=UP; j>=0; j--)
        if(Fa[u][j] != 0 && dep[Fa[u][j]] >= dep[v])
            u = Fa[u][j];

    if(u == v) return v;

    for(int j=UP; j>=0; j--){
        if(Fa[u][j] != Fa[v][j]){
            u = Fa[u][j];
            v = Fa[v][j];
        }
    }

    return Fa[u][0];
}

int Query(int A, int B, int val)
{
    int lca = LCA(A, B);
    int lcaAns = arr[lca]^val;
    A = root[A], B = root[B], lca = root[lca];
    int ret = 0;
    for(int i=15; i>=0; i--){
        int bit = (val>>i) & 1;
        if(sz[ch[A][!bit]] + sz[ch[B][!bit]] - 2 * sz[ch[lca][!bit]] > 0){
            ret += 1<<i;
            A = ch[A][!bit];
            B = ch[B][!bit];
            lca = ch[lca][!bit];
        }else A = ch[A][bit], B = ch[B][bit], lca = ch[lca][bit];
    }

    return max(ret, lcaAns);
}

int main(void)
{
    while(~scanf("%d %d", &n, &m)){
        for(int i=1; i<=n; i++) scanf("%d", &arr[i]);

        init_Graph();
        for(int i=1; i<n; i++){
            int u, v;
            scanf("%d %d", &u, &v);
            AddEdge(u, v);
            AddEdge(v, u);
        }

        tot = 0;
        dfs(1, 0);
        Doubling();

        while(m--){
            int u, v, x;
            scanf("%d %d %d", &u, &v, &x);
            printf("%d
", Query(u, v, x));
        }
    }
    return 0;
}
View Code

 

原文地址:https://www.cnblogs.com/qwertiLH/p/9141345.html