树上莫队 SPOJ COT2


  • 题意: 给一棵树,每次查询u到v路径上有多少不同的点权

首先需要证明这类题目符合区间加减性质

摘选一段vfk大牛的证明
用S(v, u)代表 v到u的路径上的结点的集合。
用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
那么S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的对称差。
简单来说就是节点出现两次消掉。
lca很讨厌,于是再定义
T(v, u) = S(root, v) xor S(root, u)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xorS(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐惧症的不要走啊 T_T)
也就是说,更新的时候,xor T(curV, targetV)就行了。
即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可

其实也就是说对于先前位置(preu,prev) 和当前位置 已知T(preu,prev) 可以很方便的计算T(curu,curv)
具体做法如下
求出T(preu,curu) 和 T(prev,curv) 暴力遍历两点之间路径即可
此时T(curu,curv) = T(preu,prev) xor T(preu,curu) xor T(prev,curv)
S(curu,curv) = T(curu,curv) xot lca(curu,curv)


下面考虑如何对树上序列进行分块了
有两种方法

  • 第一种方法:dfs的时候对每个点记录 进栈时间戳f(x) 和 出栈时间戳g(x),得到一个2n的序列
  • 对于查询(x,y) 令f(x) < f(y)
  • 如果 x 是 y 的祖先,考虑从x向下走到y 即区间[f(x) , f(y)]
    显然除了x到y路径上的点 之外 其他在区间[f(x),f(y)]出现的点都出现了两次
  • 如果x 不是 y 的祖先,那么必然是先往上走 再往下,即区间[g(x),f(y)] 再加上lca(x,y)
  • 第二种方法: 考虑对树上关键点的划分,详情见分块有关论文,证明我也没太看懂,大概的理解就是把一些距离相近的点划分成一块,减少块与块之间需要跨越的距离。

第一种方法 序列长度为2n 看起来常数似乎比第二种要大,而且每个点记录两次处理起来麻烦一点,所以我用的是第二种

#include<bits/stdc++.h>

using namespace std;

const int N = 1e5 + 10;
int col[N],vis[N],cnt[N];
int pos[N],dfn[N];
int head[N],EN,tot;
int n , m, Siz;
inline void read(int &x){
    char c = getchar();
    x = 0;
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + c - 48,c = getchar();
}
struct Q{
    int l , r, id, b;
    int x , y;
    Q(){};
    bool operator < (const Q&rhs)const{
        if(b  == rhs.b) return dfn[r] < dfn[rhs.r];///左端点先按块排序,再右端点按时间戳排序
        return b < rhs.b;
    }
}q[N];

int ans[N] , res;

struct edge{
    int v,nxt;
    edge(){};
    edge(int v,int nxt):v(v),nxt(nxt){};
}e[N];

void add(int u,int v){
    e[EN] = edge(v,head[u]);
    head[u] = EN++;
}

int f[N][22],dep[N];
int lca(int u,int v){
    if(dep[u] < dep[v]) swap(u,v);
    int d = dep[u] - dep[v];
    for(int i = 0;i <= 20;i++)
        if((1<<i) & d) u = f[u][i];
    if(u == v) return u;
    for(int i = 20;i >= 0;i--)
        if(f[u][i] != f[v][i]) u = f[u][i],v = f[v][i];
    return f[u][0];
}
int stk[N],top,b_cnt;

int dfs(int u,int fa,int d){
    for(int i = 1;i <= 20;i++) f[u][i] = f[f[u][i-1]][i-1];
    dep[u] = d;
    dfn[u] = tot++;
    int siz = 0;
    for(int i = head[u];~i;i = e[i].nxt){
        if(i == fa) continue;
        int v = e[i].v;
        f[v][0] = u;
        siz += dfs(v,i ^ 1,d + 1);
        if(siz >= Siz){
            while(siz--) pos[stk[top--]] = b_cnt;
            b_cnt++;
        }
    }
    stk[++top] = u;
    return siz + 1;
}
void init(){
    memset(head,-1,sizeof(head));
    b_cnt = top = EN = tot = 0;
}
inline void up(int u){
    if(!vis[u]) {
        if(++cnt[col[u]] == 1) res++;
        vis[u] = 1;
    }else{
        vis[u] = 0;
        if(--cnt[col[u]] == 0) res--;
    }
}
inline void work(int u,int v){
    while(u != v){
        if(dep[u] < dep[v]) swap(u,v);
        up(u),u = f[u][0];
    }
}
map<int,int> mp;
int ID;
/*
8 2
1000000000000 2 9 3 8 5 1000001 1000001
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
*/
int main(){


    read(n),read(m);
    ID = 1;mp.clear();
    for(int i = 1;i <= n;i++) {
            read(col[i]);
            if(!mp[col[i]]) mp[col[i]] = ID++;
            col[i] = mp[col[i]];
    }
    init();
    int rt = 0;
    for(int i = 1;i < n;i++){
        int u , v;
        read(u),read(v);
        if(!u) rt = v;
        else if(!v) rt = u;
        else add(u,v),add(v,u);
    }
    Siz = sqrt(n + 0.5);
    dfs(1,-1,1);
    while(top) pos[stk[top--]] = b_cnt;

    for(int i = 0;i < m;i++){
        read(q[i].l),read(q[i].r);//read(q[i].x),read(q[i].y);
        if(dfn[q[i].l] > dfn[q[i].r]) swap(q[i].l,q[i].r);
        q[i].id = i;
        q[i].b = pos[q[i].l];
    }
    sort(q,q + m);
    memset(vis,0,sizeof(vis));
    memset(cnt,0,sizeof(cnt));
    res = 0;
    int LCA = lca(q[0].l,q[0].r);
    work(q[0].l,q[0].r);
    up(LCA);
    ans[q[0].id] = res;
    up(LCA);
    for(int i = 1;i < m;i++){
        work(q[i-1].l,q[i].l);
        work(q[i-1].r,q[i].r);
        LCA = lca(q[i].l,q[i].r);
        up(LCA);
        ans[q[i].id]  = res;
        up(LCA);
    }
    for(int i = 0;i < m;i++) printf("%d
",ans[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/jiachinzhao/p/6928488.html