hdu 3804 Query on a tree (树链剖分+线段树)

Problem - 3804

  树链剖分很久之前就学过,但一直不敢写,总觉得它是一种十分高级的技巧。不过,经过一段时间的dfs和bfs训练以后,我开始对树链剖分有感觉了。于是,我赶紧翻出以前的树链剖分相关资料,最后决定把这道入门的树链剖分题给灭了!

  这题的题意是:给出一棵带边权的树,每次询问给定点 x 到编号为 1 的点(根)的路径上,不超过给定值 y 的最大边权是多少;若路径上不存在这样的边,则输出 -1。

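  将树按照重链和轻链划分以后,为每条重链单独建一棵线段树,线段树的每个结点保存其区间内所有边权的有序列表,查询时在结点上二分查找不超过给定值的最大边权。每次询问时不停地沿着通往根结点的路径移动:如果当前点在重链上(且不是链顶),就利用线段树一次跳到链的顶端;否则就经过一条轻边移动到父结点。任意点到根的路径只会经过 O(log n) 条轻边和 O(log n) 条重链,而每条重链上的一次线段树查询为 O(log² n)(访问 O(log n) 个结点,每个结点二分 O(log n)),所以单次询问的复杂度为 O(log³ n),理论上不会超时。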

  不过,鉴于是第一次写,我先是因为在不同链之间跳跃时处理不当而TLE和WA,然后又因为树的深度过大导致递归爆栈,只好用 #pragma comment(linker, "/STACK:...") 手动扩栈(这种写法只对 MSVC 编译器有效),最后又因为数组开得太大而MLE。不过,排除万难以后,我的代码最终以 4s+ 和 32M 压线通过了!

52 SCAU_Lyon 4421MS 32728K 4799B C++ 2013-04-12 01:05:53

代码如下:

View Code
  1 #pragma comment(linker, "/STACK:102400000,102400000")
  2 
  3 #include <cstdio>
  4 #include <cstring>
  5 #include <iostream>
  6 #include <vector>
  7 #include <vector>
  8 #include <map>
  9 #include <algorithm>
 10 
 11 using namespace std;
 12 
 13 #define PB push_back
 14 #define MPR make_pair
 15 #define _clr(x) memset(x, 0, sizeof(x))
 16 #define FI first
 17 #define SE second
 18 #define ALL(x) (x).begin(), (x).end()
 19 #define SZ(x) ((int) (x).size())
 20 #define REP(i, n) for (int i = 0; i < (n); i++)
 21 #define REP_1(i, n) for (int i = 1; i <= (n); i++)
 22 
 23 typedef vector<int> VI;
 24 typedef pair<int, int> PII;
 25 typedef vector<PII> VPII;  // NOTE(review): not used anywhere in this listing
 26 const int N = 111111;  // array capacity; node ids are used as indices 1..n
 27 VI rel[N], seg[N << 2];  // rel: adjacency lists; seg: nodes of every chain's segment tree, packed at per-chain offsets
 28 int cnt[N], son[N], id[N], top[N], offset[N], fa[N], nid[N], len[N];  // subtree size, heavy son (0 = none), dfs visit id, chain top, chain's base index in seg, parent, index inside chain (top = 0), heavy edges below node (chain length after getChain)
 29 map<int, int> val[N];  // val[u][v] = weight of tree edge (u, v), stored in both directions
 30 
 31 void input(int n) {
 32     int x, y, w;
 33     REP_1(i, n) {
 34         rel[i].clear();
 35         val[i].clear();
 36         id[i] = top[i] = fa[i] = len[i] = 0;
 37     }
 38     REP(i, n - 1) {
 39         scanf("%d%d%d", &x, &y, &w);
 40         rel[x].PB(y);
 41         rel[y].PB(x);
 42         val[x][y] = w;
 43         val[y][x] = w;
 44     }
 45 }
 46 
 47 int curCnt;
 48 
 49 void dfs(int x) {
 50     id[x] = ++curCnt;
 51     son[x] = 0;
 52     cnt[x] = 1;
 53     REP(i, SZ(rel[x])) {
 54         int t = rel[x][i];
 55         if (id[t]) continue;
 56         fa[t] = x;
 57         dfs(t);
 58         cnt[x] += cnt[t];
 59         if (~son[x]) {
 60             if (cnt[son[x]] < cnt[t]) son[x] = t;
 61         } else {
 62             son[x] = t;
 63         }
 64     }
 65     len[x] = son[x] ? len[son[x]] + 1 : 0;
 66 }
 67 
 68 #define lson l, m, rt << 1, os
 69 #define rson m + 1, r, rt << 1 | 1, os
 70 
 71 void build(int l, int r, int rt, int os) {
 72     seg[rt + os].clear();
 73     if (l == r) return ;
 74     int m = (l + r) >> 1;
 75     build(lson);
 76     build(rson);
 77 }
 78 
 79 void insert(int x, int p, int l, int r, int rt, int os) {
 80     seg[rt + os].PB(x);
 81     if (l == r) return ;
 82     int m = (l + r) >> 1;
 83     if (p <= m) insert(x, p, lson);
 84     else insert(x, p, rson);
 85 }
 86 
 87 int query(int L, int R, int x, int l, int r, int rt, int os) {
 88     if (L <= l && r <= R) {
 89         int lb = upper_bound(ALL(seg[rt + os]), x) - seg[rt + os].begin() - 1;
 90 //        REP(i, SZ(seg[rt + os])) cout << seg[rt + os][i] << ' '; cout << endl;
 91 //        cout << "!! " << x << ' ' << lb << endl;
 92         if (lb < 0) return -1;
 93         return seg[rt + os][lb];
 94     }
 95     int m = (l + r) >> 1, ret = -1;
 96     if (L <= m) ret = max(ret, query(L, R, x, lson));
 97     if (m < R) ret = max(ret, query(L, R, x, rson));
 98     return ret;
 99 }
100 
101 bool vis[N];
102 
103 void getChain(int n) {
104     int offsetSum = 1;
105     REP_1(i, n) {
106         if (top[i]) continue;
107         int cur, low = i;
108         while (son[fa[low]] == low) low = fa[low];
109         if (len[low]) build(1, len[low], 1, offsetSum);
110         else {
111             top[low] = low;
112             offsetSum++;
113             continue;
114         }
115         cur = low;
116         int c = 0;
117         while (cur) {
118 //            cout << cur << endl;
119             offset[cur] = offsetSum;
120             len[cur] = len[low];
121             top[cur] = low;
122             nid[cur] = c++;
123             if (son[cur]) insert(val[cur][son[cur]], c, 1, len[low], 1, offsetSum);
124             cur = son[cur];
125         }
126         c--;
127         cur = low;
128         REP_1(i, c << 2) {
129             sort(ALL(seg[offsetSum + i]));
130         }
131         offsetSum += c + 1 << 2;
132     }
133     _clr(vis);
134 }
135 
136 void PRE(int n) {
137     curCnt = 0;
138     dfs(1);
139 //    REP_1(i, n) {
140 //        cout << id[i] << ' ' << son[i] << ' ' << cnt[i] << endl;
141 //    }
142     getChain(n);
143 //    REP_1(i, n) {
144 //        cout << i << ": " << top[i] << ' ' << nid[i] << ' ' << offset[i] << endl;
145 //    }
146 }
147 
148 int query(int x, int y) {
149     int ret = -1;
150     while (x != 1) {
151 //        cout << x << " ??? " << fa[x] << ' ' << top[x] << endl;
152         if (top[x] == x) {
153             int w = val[x][fa[x]];
154 //            cout << x << ' ' << w << endl;
155             if (w <= y) ret = max(ret, w);
156             x = fa[x];
157         } else {
158 //            cout << x << " !! " << nid[x] << ' ' << len[x] << endl;
159             ret = max(ret, query(1, nid[x], y, 1, len[x], 1, offset[x]));
160             x = top[x];
161         }
162 //        cout << "ret " << ret << ' ' << x << endl;
163     }
164     return ret;
165 }
166 
167 void work(int n) {
168     int x, y;
169     REP(i, n) {
170         scanf("%d%d", &x, &y);
171         printf("%d\n", query(x, y));
172     }
173 }
174 
175 int main() {
176 //    VI tmp;
177 //    tmp.clear();
178 //    tmp.PB(1), tmp.PB(2), tmp.PB(4), tmp.PB(5);
179 //    cout << ((int) (upper_bound(ALL(tmp), 0) - tmp.begin())) << endl;
180 
181 //    freopen("in", "r", stdin);
182     int T, n;
183     scanf("%d", &T);
184     while (T--) {
185         scanf("%d", &n);
186         input(n);
187         PRE(n);
188         scanf("%d", &n);
189         work(n);
190     }
191     return 0;
192 }

UPD:

  改用非递归(手写栈)的写法后,时间和空间开销都略有减少。

View Code
  1 #include <cstdio>
  2 #include <cstring>
  3 #include <iostream>
  4 #include <vector>
  5 #include <algorithm>
  6 #include <queue>
  7 #include <stack>
  8 
  9 using namespace std;
 10 
 11 #define REP(i, n) for (int i = 0; i < (n); i++)
 12 #define REP_1(i, n) for (int i = 1; i <= (n); i++)
 13 #define FI first
 14 #define SE second
 15 #define PB push_back
 16 #define SZ(x) ((int) (x).size())
 17 #define MPR make_pair
 18 #define ALL(x) (x).begin(), (x).end()
 19 
 20 typedef vector<int> VI;
 21 const int N = 111111;
 22 
 23 int offset[N], chainLen[N], preferSon[N], pre[N], cnt[N], top[N], weight[N];  // chain's base index in seg, heavy edges below node, heavy child (0 = none), parent, subtree size, chain top, weight of edge to parent
 24 bool vis[N];  // visit marks for the explicit-stack traversals in BFS()
 25 VI rel[N], val[N];  // rel[u][k]: k-th neighbour of u; val[u][k]: weight of that edge (index-parallel lists)
 26 
 27 void input(int n) {
 28     REP_1(i, n) {
 29         rel[i].clear();
 30         val[i].clear();
 31     }
 32     n--;
 33     int x, y, w;
 34     REP(i, n) {
 35         scanf("%d%d%d", &x, &y, &w);
 36         rel[x].PB(y);
 37         val[x].PB(w);
 38         rel[y].PB(x);
 39         val[y].PB(w);
 40     }
 41 }
 42 
 43 #define lson l, m, rt << 1, offs
 44 #define rson m + 1, r, rt << 1 | 1, offs
 45 VI seg[N << 2];
 46 
 47 void build(int l, int r, int rt, int offs) {
 48     seg[rt + offs].clear();
 49     if (l == r) return ;
 50     int m = (l + r) >> 1;
 51     build(lson);
 52     build(rson);
 53 }
 54 
 55 void update(int x, int p, int l, int r, int rt, int offs) {
 56     seg[rt + offs].PB(x);
 57     if (l == r) {
 58         return ;
 59     }
 60     int m = (l + r) >> 1;
 61     if (p <= m) update(x, p, lson);
 62     else update(x, p, rson);
 63 }
 64 
 65 void sortNode(int l, int r, int rt, int offs) {
 66     sort(ALL(seg[rt + offs]));
 67     seg[rt + offs].end() = unique(ALL(seg[rt + offs]));
 68     if (l == r) return ;
 69     int m = (l + r) >> 1;
 70     sortNode(lson);
 71     sortNode(rson);
 72 }
 73 
 74 int query(int L, int R, int x, int l, int r, int rt, int offs) {
 75     int ret = -1;
 76     if (L <= l && r <= R) {
 77         VI::iterator ii = upper_bound(ALL(seg[rt + offs]), x);
 78         if (ii == seg[rt + offs].begin()) return -1;
 79         ii--;
 80         return *ii;
 81     }
 82     int m = (l + r) >> 1;
 83     if (L <= m) ret = max(ret, query(L, R, x, lson));
 84     if (m < R) ret = max(ret, query(L, R, x, rson));
 85     return ret;
 86 }
 87 
 88 void BFS(int n) {
 89     REP_1(i, n) {
 90         preferSon[i] = pre[i] = 0;
 91         vis[i] = false;
 92     }
 93     stack<int> S;
 94     cnt[0] = preferSon[0] = top[0] = 0;
 95     while (!S.empty()) S.pop();
 96     S.push(1);
 97     while (!S.empty()) {
 98         int cur = S.top();
 99         S.pop();
100         if (vis[cur]) {
101             preferSon[cur] = 0;
102             cnt[cur] = 1;
103             int sz = SZ(rel[cur]);
104             REP(i, sz) {
105                 int t = rel[cur][i];
106                 if (pre[t] != cur) continue;
107                 cnt[cur] += cnt[t];
108                 if (cnt[preferSon[cur]] < cnt[t]) {
109                     preferSon[cur] = t;
110                 }
111             }
112             chainLen[cur] = preferSon[cur] ? chainLen[preferSon[cur]] + 1 : 0;
113             vis[cur] = false;
114         } else {
115             vis[cur] = true;
116             S.push(cur);
117             int sz = SZ(rel[cur]);
118             REP(i, sz) {
119                 int t = rel[cur][i];
120                 if (vis[t]) continue;
121                 pre[t] = cur;
122                 weight[t] = val[cur][i];
123                 S.push(t);
124             }
125         }
126     }
127 //    REP_1(i, n) {
128 //        cout << i << ": " << pre[i] << ' ' << cnt[i] << ' ' << preferSon[i] << ' ' << chainLen[i] << endl;
129 //    }
130     while (!S.empty()) S.pop();
131     int offsetSum = 1;
132     S.push(1);
133     vis[1] = true;
134 //    puts("here?");
135     while (!S.empty()) {
136         int cur = S.top();
137         S.pop();
138         if (preferSon[pre[cur]] != cur) {
139             offset[cur] = offsetSum;
140             top[cur] = cur;
141             if (chainLen[cur]) {
142                 build(1, chainLen[cur], 1, offsetSum);
143                 offsetSum += chainLen[cur] << 2;
144             } else {
145                 offsetSum++;
146             }
147         }
148 //        cout << cur << endl;
149         int sz = SZ(rel[cur]);
150         bool hasPrefer = false;
151         REP(i, sz) {
152             int t = rel[cur][i];
153             if (vis[t]) continue;
154             S.push(t);
155             vis[t] = true;
156             if (preferSon[cur] == t) {
157                 hasPrefer = true;
158                 offset[t] = offset[cur];
159                 top[t] = top[cur];
160                 update(val[cur][i], chainLen[top[t]] - chainLen[t], 1, chainLen[top[t]], 1, offset[t]);
161             }
162         }
163 //        cout << cur << ' ' << chainLen[top[cur]] << ' ' << offset[cur] << endl;
164         if (!hasPrefer && chainLen[top[cur]]) sortNode(1, chainLen[top[cur]], 1, offset[cur]);
165     }
166 //    REP_1(i, n) {
167 //        cout << i << ": " << offset[i] << ' ' << top[i] << endl;
168 //    }
169 //    puts("no!");
170 }
171 
172 void PRE(int n) {
173     BFS(n);
174 }
175 
176 int query(int x, int y) {
177     int ret = -1;
178     while (x != 1) {
179 //        cout << x << endl;
180         if (top[x] == x) {
181             if (weight[x] <= y) ret = max(ret, weight[x]);
182             x = pre[x];
183         } else {
184             ret = max(ret, query(1,  chainLen[top[x]] - chainLen[x], y, 1, chainLen[top[x]], 1, offset[x]));
185             x = top[x];
186         }
187 //        cout << "~~ " << ret << endl;
188     }
189     return ret;
190 }
191 
192 void work(int n) {
193     int x, y;
194     REP(i, n) {
195         scanf("%d%d", &x, &y);
196         printf("%d\n", query(x, y));
197     }
198 }
199 
200 int main() {
201 //    freopen("in", "r", stdin);
202     int T, n;
203     scanf("%d", &T);
204     while (T-- && ~scanf("%d", &n)) {
205         input(n);
206         PRE(n);
207         scanf("%d", &n);
208         work(n);
209 //        cout << "ok!!" << endl;
210     }
211     return 0;
212 }

——written by Lyon

原文地址:https://www.cnblogs.com/LyonLys/p/hdu_3804_Lyon.html