【SDOI 2011】染色

【题目链接】

            点击打开链接

【算法】

          树链剖分

【代码】

         本题中，笔者求最近公共祖先时并没有用树链剖分“往上跳”的方式，而是用倍增法。笔者认为这样比较好写，代码可读性也比较高。

         此外，笔者的线段树并没有使用懒惰标记：只要当前访问节点所代表区间的颜色段数为 1（即整段同色），就在向下访问前把该颜色下传给左右儿子。

         

#include<bits/stdc++.h>
using namespace std;
// Number of binary-lifting levels kept per vertex in anc[][].
const int MAXLOG = 18;
// Array bound for vertex-indexed arrays (1e5 plus slack).
const int MAXN = 1e5 + 10;

// Scratch variables of main(); timer is also the DFS-order counter advanced by dfs2.
int i,n,m,timer,x,y,c,t;

// Heavy-light decomposition data, indexed by vertex unless stated otherwise.
int dep[MAXN];          // depth, root = 0
int fa[MAXN];           // parent, 0 for the root
// NOTE: with `using namespace std;` under C++17 an unqualified `size` is ambiguous
// with std::size; refer to this array as `::size`.
int size[MAXN];         // subtree size
int son[MAXN];          // heavy child (largest subtree), 0 for a leaf
int dfn[MAXN];          // position of the vertex in the heavy-first DFS order
int top[MAXN];          // head of the heavy chain containing the vertex
int val[MAXN];          // colour read from the input
int pos[MAXN];          // indexed by DFS position: the vertex stored there
int anc[MAXN][MAXLOG];  // anc[v][k]: 2^k-th ancestor of v, 0 if it does not exist
vector<int> e[MAXN];    // adjacency lists of the tree
char opt[10];           // operation letter read for each query

// Segment tree over DFS positions. Every node describes its range [l, r] by:
//   sum    - the number of maximal same-colour segments inside the range,
//   lcover - the colour at position l,
//   rcover - the colour at position r.
// No separate lazy tag is stored: a node with sum == 1 is uniformly coloured, and its
// colour is copied into both children (push_down) before the tree is descended through it.
struct SegmentTree {
        struct Node {
                int l,r,sum,lcover,rcover;
        } Tree[MAXN*4];
        // Rebuilds node `index` from its children; when the colour ending the left child
        // equals the colour starting the right child, those two segments merge into one.
        inline void push_up(int index) {
                const Node &lc = Tree[index<<1];
                const Node &rc = Tree[index<<1|1];
                Node &cur = Tree[index];
                cur.lcover = lc.lcover;
                cur.rcover = rc.rcover;
                cur.sum = lc.sum + rc.sum - (lc.rcover == rc.lcover ? 1 : 0);
        }
        // Copies the colour of a uniformly coloured node down to both children.
        inline void push_down(int index) {
                Node &lc = Tree[index<<1];
                Node &rc = Tree[index<<1|1];
                lc.sum = 1;
                rc.sum = 1;
                lc.lcover = lc.rcover = Tree[index].lcover;
                rc.lcover = rc.rcover = Tree[index].rcover;
        }
        // Builds node `index` for the range [l, r]; the leaf at position l takes the input
        // colour of the vertex placed there (val[pos[l]]).
        inline void build(int index,int l,int r) {
                Node &cur = Tree[index];
                cur.l = l;
                cur.r = r;
                if (l != r) {
                        const int mid = (l + r) >> 1;
                        build(index<<1,l,mid);
                        build(index<<1|1,mid+1,r);
                        push_up(index);
                } else {
                        cur.lcover = cur.rcover = val[pos[l]];
                        cur.sum = 1;
                }
        }
        // Paints the range [l, r] (contained in this node's range) with `color`.
        inline void modify(int index,int l,int r,int color) {
                Node &cur = Tree[index];
                if (cur.l == l && cur.r == r) {
                        cur.lcover = cur.rcover = color;
                        cur.sum = 1;
                        return;
                }
                if (cur.sum == 1) push_down(index);
                const int mid = (cur.l + cur.r) >> 1;
                if (r <= mid) {
                        modify(index<<1,l,r,color);
                } else if (l > mid) {
                        modify(index<<1|1,l,r,color);
                } else {
                        modify(index<<1,l,mid,color);
                        modify(index<<1|1,mid+1,r,color);
                }
                push_up(index);
        }
        // Counts colour segments in the range [l, r] (contained in this node's range).
        inline int query(int index,int l,int r) {
                const Node &cur = Tree[index];
                if (cur.l == l && cur.r == r) return cur.sum;
                if (cur.sum == 1) push_down(index);
                const int mid = (cur.l + cur.r) >> 1;
                if (r <= mid) return query(index<<1,l,r);
                if (l > mid) return query(index<<1|1,l,r);
                // The range straddles mid: a segment crossing the boundary must be counted once.
                const int merged = (Tree[index<<1].rcover == Tree[index<<1|1].lcover) ? 1 : 0;
                return query(index<<1,l,mid) + query(index<<1|1,mid+1,r) - merged;
        }
        // Returns the colour stored at position p, walking down from node `index`.
        inline int get(int index,int p) {
                while (Tree[index].l != Tree[index].r) {
                        if (Tree[index].sum == 1) push_down(index);
                        const int mid = (Tree[index].l + Tree[index].r) >> 1;
                        index = (p <= mid) ? (index<<1) : ((index<<1)|1);
                }
                return Tree[index].lcover;
        }
} T;
inline void dfs1(int x) {
        int i,y;
        anc[x][0] = fa[x];
        for (i = 1; i < MAXLOG; i++) {
                if (dep[x] < (1 << i)) break; 
                anc[x][i] = anc[anc[x][i-1]][i-1];
        }
        size[x] = 1;  
        for (i = 0; i < e[x].size(); i++) {
                y = e[x][i];
                if (fa[x] != y) {
                        dep[y] = dep[x] + 1;
                        fa[y] = x;
                        dfs1(y);
                        size[x] += size[y];
                        if (size[y] > size[son[x]]) son[x] = y;
                }
        }
}
// Second DFS of the heavy-light decomposition. It numbers vertices so that every heavy
// chain occupies a contiguous range of positions, and for each vertex records:
//   top[x] - the head of its chain (tp),
//   dfn[x] - its position, with pos[] as the inverse map (position -> vertex).
inline void dfs2(int x,int tp) {
        top[x] = tp;
        dfn[x] = ++timer;
        pos[dfn[x]] = x;
        // Descend into the heavy child first so the chain continues at the next position.
        if (son[x]) dfs2(son[x],tp);
        for (int child : e[x]) {
                // Each remaining (light) child starts a new chain headed by itself.
                if (child == fa[x] || child == son[x]) continue;
                dfs2(child,child);
        }
}
// Lowest common ancestor of x and y by binary lifting over anc[][] (filled in dfs1).
inline int lca(int x,int y) {
        // Make y the deeper of the two vertices.
        if (dep[x] > dep[y]) std::swap(x,y);
        // Lift y by the depth difference, taking the 2^k jump for every set bit k.
        int k = 0;
        for (int diff = dep[y] - dep[x]; diff != 0 && k < MAXLOG; diff >>= 1, ++k) {
                if (diff & 1) y = anc[y][k];
        }
        if (x == y) return x;
        // Raise both vertices while their ancestors differ; they stop just below the LCA.
        for (k = MAXLOG - 1; k >= 0; --k) {
                if (anc[x][k] == anc[y][k]) continue;
                x = anc[x][k];
                y = anc[y][k];
        }
        return anc[x][0];
}
// Recolours every vertex on the path from x up to its ancestor y (both inclusive) with c.
// While x is on a different heavy chain than y, the piece of x's chain from its head down
// to x is painted, and x jumps to the head's parent. The final piece lies on y's chain.
inline void modify(int x,int y,int c) {
        for (int head = top[x]; head != top[y]; head = top[x]) {
                T.modify(1,dfn[head],dfn[x],c);
                x = fa[head];
        }
        T.modify(1,dfn[y],dfn[x],c);
}
// Number of colour segments on the path from x up to its ancestor y (both inclusive).
// Each heavy-chain piece contributes its own segment count. When a chain head has the
// same colour as its parent, the two adjacent segments are one, so 1 is subtracted.
inline int query(int x,int y) {
        const int ty = top[y];
        int segments = 0;
        while (top[x] != ty) {
                const int head = top[x];
                segments += T.query(1,dfn[head],dfn[x]);
                const bool joinsParent = T.get(1,dfn[head]) == T.get(1,dfn[fa[head]]);
                if (joinsParent) segments--;
                x = fa[head];
        }
        return segments + T.query(1,dfn[y],dfn[x]);
}

int main() {
        
        scanf("%d%d",&n,&m);
        for (i = 1; i <= n; i++) scanf("%d",&val[i]);
        for (i = 1; i < n; i++) {
                scanf("%d%d",&x,&y);
                e[x].push_back(y);
                e[y].push_back(x);
        }
            
        dfs1(1);
        dfs2(1,1);
        T.build(1,1,timer);
        
        while (m--) {
                scanf("%s",opt);
                if (opt[0] == 'C') {
                        scanf("%d%d%d",&x,&y,&c);
                        t = lca(x,y);
                        modify(x,t,c); modify(y,t,c);
                } else {
                        scanf("%d%d",&x,&y);
                        t = lca(x,y);
                        printf("%d
",query(x,t)+query(y,t)-1);
                }
        }
    
    return 0;
}
原文地址:https://www.cnblogs.com/evenbao/p/9196352.html