树的统计Count HYSBZ

树的统计Count

HYSBZ - 1036

参考：《高级数据结构》第 329 页。

感叹一句：分块大法好啊……（思路：按 DFS 顺序把树切成若干个大小约为 √n 的连通块，每个块内维护从块根到各结点路径上的前缀和与前缀最大值；修改时只需重算该结点在块内的子树，查询时则按块整段跳跃。）

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <iostream>
  4 #include <cmath>
  5 #include <algorithm>
  6 using namespace std;
  7 const int MAXN = 310000;
  8 const int inf = 0x3f3f3f3f;
  9 
 10 //struct Edge{
 11 //    int v, nxt;
 12 //}e1[MAXN<<1], e2[MAXN<<1];
 13 //int head1[MAXN], head2[MAXN];
 14 //int cnt1, cnt2;
 15 //void init(int &cnt, int *&head){
 16 //    memset(head, -1, sizeof(head));
 17 //    cnt = 0;    
 18 //}
 19 //void add(int u, int v, int *&head, Edge *&e, int &cnt){
 20 //    e[cnt++] = Edge{v, head[u]};
 21 //    head[u] = cnt++;
 22 //}
 23 
 24 struct Edge{
 25     int cnt, head[MAXN<<1], nxt[MAXN<<1], to[MAXN<<1];
 26     void init(){
 27         memset(head, -1, sizeof(head));
 28         cnt = 0;
 29     }
 30     void add(int u, int v){
 31         to[cnt] = v;
 32         nxt[cnt] = head[u];
 33         head[u] = cnt++;
 34     }
 35 }e1, e2;
 36 int val[MAXN], dep[MAXN], f[MAXN], bkrt[MAXN];
 37 int sz[MAXN], limit;
 38 int VSum[MAXN], VMax[MAXN];
 39 
 40 void buildBlock(int u, int fa, int d){
 41     dep[u] = d;
 42     f[u] = fa;
 43     int curB = bkrt[u];
 44     for(int i = e1.head[u]; ~i; i = e1.nxt[i]){
 45         int v = e1.to[i];
 46         if(v != fa){
 47             if(sz[curB] + 1 < limit){
 48                 e2.add(u, v);
 49                 bkrt[v] = curB;
 50                 sz[curB]++;
 51             }
 52             buildBlock(v, u, d + 1);
 53         }
 54     }
 55 }
 56 
 57 void InitData(int u, int sumval, int maxval){
 58     sumval += val[u]; maxval = max(maxval, val[u]);
 59     VSum[u] = sumval; VMax[u] = maxval;
 60     for(int i = e2.head[u]; ~i; i = e2.nxt[i]){
 61         InitData(e2.to[i], sumval, maxval);
 62     }
 63 }
 64 void update(int u, int data){
 65     val[u] = data;
 66     if(u == bkrt[u]) InitData(u, 0, -inf);
 67     else InitData(u, VSum[f[u]], VMax[f[u]]);
 68 }
 69 pair<int, int> query(int a, int b){
 70     pair<int, int> ans(0, -inf);
 71     while(a != b){
 72         if(dep[a] < dep[b]) swap(a, b);
 73         if(bkrt[a] == bkrt[b]){
 74             ans.first += val[a];
 75             ans.second = max(ans.second, val[a]);
 76             a = f[a];
 77         }else{
 78             if(dep[bkrt[a]] < dep[bkrt[b]]) swap(a, b);
 79             ans.first += VSum[a];
 80             ans.second = max(ans.second, VMax[a]);
 81             a = f[bkrt[a]];
 82         }
 83     }
 84     ans.first += val[a];
 85     ans.second = max(ans.second, val[a]);
 86     return ans;
 87 }
 88 
 89 int main(){
 90     int n;
 91     int u, v;
 92     scanf("%d", &n);
 93     limit = sqrt(n) + 1;
 94     e1.init(); e2.init();
 95     for(int i = 1; i < n; i++){
 96         scanf("%d %d", &u, &v);
 97         e1.add(u, v);
 98         e1.add(v, u);
 99     }
100     for(int i = 1; i <= n; i++){
101         scanf("%d", &val[i]);
102         bkrt[i] = i;
103     }
104     memset(sz, 0, sizeof(sz));
105     buildBlock(1, 0, 0);
106     for(int i = 1; i <= n; i++){
107         if(bkrt[i] == i){
108             InitData(i, 0, -inf);
109         }
110     }
111     int m;
112     char op[20];
113     scanf("%d", &m);
114     for(int i = 0; i < m; i++){
115         scanf("%s %d %d", op, &u, &v);
116         pair<int, int> ans;
117         if(op[1] == 'M'){
118             ans = query(u, v);
119             printf("%d
", ans.second);
120         }else if(op[1] == 'S'){
121             ans = query(u, v);
122             printf("%d
", ans.first);
123         }else{
124             update(u, v);
125         }
126     }
127     return 0;
128 }
View Code

虽然还不是很理解，先码下来吧。

原文地址:https://www.cnblogs.com/yijiull/p/8362066.html