【ZJOI 2008】树的统计

【题目链接】

           点击打开链接（原超链接在转载中丢失；题目为 BZOJ 1036 / 洛谷 P2590：https://www.luogu.com.cn/problem/P2590）

【算法】

         树链剖分（重链剖分）模板题：两次 DFS 求出重儿子、链顶和 DFS 序，按 DFS 序建线段树维护区间和与区间最大值；路径询问被拆成 O(log N) 条重链上的区间分别查询。

【代码】

        

#include<bits/stdc++.h>
using namespace std;
#define MAXN 30000

/// One node of the segment tree built over the heavy-light DFS order.
///   l, r : inclusive range of DFS positions covered by this node
///   maxn : maximum vertex weight over [l, r]
///   sum  : sum of vertex weights over [l, r]
struct SegmentTree {
    int l,r,maxn,sum;
};

// Nodes are stored heap-style: the children of node k are 2k and 2k+1.
// For n leaves the largest index used is below 2 * 2^ceil(log2 n), which is
// less than 4n. 3*MAXN only happens to suffice for the current MAXN (<= 32768),
// so size the array at 4*MAXN to stay safe if MAXN is ever raised.
SegmentTree tree[MAXN*4];

int i, u, v, t;      // scratch values used while reading input in main()
int N, Q;            // number of vertices, number of commands
int num;             // running DFS counter, advanced by dfs2

// Per-vertex heavy-light data (vertices are 1..N); all zero-initialized.
int size[MAXN+10];   // subtree size, filled by dfs1
int fa[MAXN+10];     // parent vertex (stays 0 for the root)
int son[MAXN+10];    // child with the largest subtree (0 for a leaf)
int pos[MAXN+10];    // pos[k] = vertex placed at DFS position k
int top[MAXN+10];    // head vertex of the heavy chain containing the vertex
int dep[MAXN+10];    // depth (stays 0 for the root)
int id[MAXN+10];     // DFS position assigned to the vertex by dfs2
int a[MAXN+10];      // initial vertex weights read from input

string opt;               // command word of the current query
vector<int> E[MAXN+10];   // adjacency lists of the undirected tree

/// Reads an integer from stdin into x.
/// Non-digit characters before the number are skipped, and every '-' among
/// them flips the sign. Reading stops at (and consumes) the first non-digit
/// after the number. If EOF is reached before any digit, x becomes 0 instead
/// of the reader spinning forever on EOF.
template <typename T> inline void read(T &x) {
    int f = 1; x = 0;
    // Keep the character in an int so EOF stays distinguishable from data and
    // isdigit() always receives a value it is defined for.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) {
        if (c == '-') f = -f;
    }
    for (; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
    x *= f;
}

/// Writes the decimal representation of integer x to stdout (no newline).
/// Negative values are printed digit-by-digit without negating x itself.
/// This makes the most negative value of T (e.g. INT_MIN) print correctly
/// instead of overflowing on `-x`.
template <typename T> inline void write(T x) {
    if (x < 0) {
        putchar('-');
        // Integer division truncates toward zero, so x / 10 and x % 10 are
        // both <= 0 here and negating them cannot overflow.
        T rest = -(x / 10);
        if (rest > 0) write(rest);
        putchar('0' - x % 10);
        return;
    }
    if (x > 9) write(x/10);
    putchar(x%10+'0');
}

/// Writes integer x to stdout and terminates the line.
template <typename T> inline void writeln(T x) {
    write(x);
    putchar('\n');
}

inline void dfs1(int x) {
    int i,y;
    size[x] = 1; son[x] = 0;
    for (i = 0; i < E[x].size(); i++) {
        y = E[x][i];
        if (y == fa[x]) continue;
        dep[y] = dep[x] + 1;
        fa[y]= x;
        dfs1(y);
        size[x] += size[y];
        if (size[y] > size[son[x]]) son[x] = y;
    }
}

inline void dfs2(int x,int tp) {
    int i,y;
    id[x] = ++num; top[x] = tp; pos[num] = x; 
    if (son[x]) dfs2(son[x],tp);
    for (i = 0; i < E[x].size(); i++) {
        y = E[x][i];
        if ((y != fa[x]) && (y != son[x]))
            dfs2(y,y);
    }    
}

/// Builds segment-tree node `index` covering DFS positions [l, r].
/// A leaf takes the weight of the vertex at that position (a[pos[l]]). An
/// internal node stores the sum and the max of its two children.
inline void build(int index,int l,int r) {
    tree[index].l = l;
    tree[index].r = r;
    if (l != r) {
        const int half = (l + r) >> 1;
        const int lc = index * 2;
        const int rc = lc + 1;
        build(lc, l, half);
        build(rc, half + 1, r);
        tree[index].sum = tree[lc].sum + tree[rc].sum;
        tree[index].maxn = std::max(tree[lc].maxn, tree[rc].maxn);
        return;
    }
    tree[index].sum = tree[index].maxn = a[pos[l]];
}

/// Point update: sets the value at DFS position `pos` to t.
/// Then recomputes max and sum on the way back up to node `index`.
/// (Inside this function the parameter `pos` shadows the global pos[] array.)
inline void modify(int index,int pos,int t) {
    if (tree[index].l == tree[index].r) {
        tree[index].sum = t;
        tree[index].maxn = t;
        return;
    }
    const int half = (tree[index].l + tree[index].r) >> 1;
    const int lc = index * 2;
    const int rc = lc + 1;
    modify(pos <= half ? lc : rc, pos, t);
    tree[index].maxn = std::max(tree[lc].maxn, tree[rc].maxn);
    tree[index].sum = tree[lc].sum + tree[rc].sum;
}

/// Returns the maximum value over DFS positions [l, r].
/// [l, r] must lie inside node `index`'s range. The query descends into
/// whichever children overlap it, splitting at the node's midpoint when
/// both do.
inline int query_max(int index,int l,int r) {
    if (tree[index].l == l && tree[index].r == r) return tree[index].maxn;
    const int half = (tree[index].l + tree[index].r) >> 1;
    if (r <= half) return query_max(index * 2, l, r);
    if (l > half) return query_max(index * 2 + 1, l, r);
    const int leftBest = query_max(index * 2, l, half);
    const int rightBest = query_max(index * 2 + 1, half + 1, r);
    return std::max(leftBest, rightBest);
}

/// Returns the sum of values over DFS positions [l, r].
/// [l, r] must lie inside node `index`'s range. The query descends into
/// whichever children overlap it, splitting at the node's midpoint when
/// both do.
inline int query_sum(int index,int l,int r) {
    if (tree[index].l == l && tree[index].r == r) return tree[index].sum;
    const int half = (tree[index].l + tree[index].r) >> 1;
    if (r <= half) return query_sum(index * 2, l, r);
    if (l > half) return query_sum(index * 2 + 1, l, r);
    const int leftPart = query_sum(index * 2, l, half);
    const int rightPart = query_sum(index * 2 + 1, half + 1, r);
    return leftPart + rightPart;
}

/// Maximum vertex weight on the tree path between u and v, endpoints included.
/// While u and v lie on different heavy chains, the endpoint whose chain head
/// is deeper is lifted. Its chain segment is one contiguous DFS range
/// [id[head], id[endpoint]], which is queried before jumping to the head's
/// parent. Once both share a chain, the range between them is queried last.
inline int query1(int u,int v) {
    int best = -2e9;
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) std::swap(u, v);
        // v is now the endpoint whose chain head is at least as deep.
        best = std::max(best, query_max(1, id[top[v]], id[v]));
        v = fa[top[v]];
    }
    const int lo = std::min(id[u], id[v]);
    const int hi = std::max(id[u], id[v]);
    return std::max(best, query_max(1, lo, hi));
}

/// Sum of vertex weights on the tree path between u and v, endpoints included.
/// Works chain by chain like query1: lift the endpoint with the deeper chain
/// head, add its contiguous chain segment, and repeat until both are on one
/// chain. Then add the range between them.
inline int query2(int u,int v) {
    int total = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) std::swap(u, v);
        // v is now the endpoint whose chain head is at least as deep.
        total += query_sum(1, id[top[v]], id[v]);
        v = fa[top[v]];
    }
    if (id[v] < id[u]) std::swap(u, v);
    total += query_sum(1, id[u], id[v]);
    return total;
}

/// Reads N, the N-1 tree edges, the N vertex weights and Q commands.
/// Prints one line per query:
///   CHANGE u t                -> set the weight of vertex u to t (no output)
///   QMAX u v                  -> maximum weight on the path between u and v
///   any other word (QSUM u v) -> sum of weights on the path between u and v
int main() {

    read(N);
    for (int k = 1; k < N; k++) {
        read(u); read(v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    // Root the tree at vertex 1, then number vertices so that each heavy
    // chain maps to one contiguous range of the segment tree.
    dfs1(1);
    dfs2(1,1);
    for (int k = 1; k <= N; k++) read(a[k]);

    build(1,1,num);

    read(Q);

    while (Q--) {
        // Mixing `cin >> opt` with the getchar-based read() works only because
        // iostreams remain synchronized with stdio (sync_with_stdio is never
        // turned off). If the command word cannot be read (input ended early),
        // stop instead of re-running the previous command on missing operands.
        if (!(cin >> opt)) break;
        if (opt == "CHANGE") {
            read(u); read(t);
            modify(1,id[u],t);
        } else if (opt == "QMAX") {
            read(u); read(v);
            writeln(query1(u,v));
        } else {
            read(u); read(v);
            writeln(query2(u,v));
        }
    }

    return 0;

}
原文地址:https://www.cnblogs.com/evenbao/p/9196436.html