倍增/线段树维护树的直径 hdu5993/2016icpc青岛L

题意:

给一棵树,每次询问删掉两条边,问剩下的三棵树的最大直径

点10W,询问10W,询问相互独立

Solution:

考虑线段树/倍增维护树的直径

考虑一个点集的区间 [l, r]

而我们知道了有 l <= k < r,

且知道 [l, k] 和 [k + 1, r] 两个区间的最长链的端点及长度

假设两个区间的直径端点分别为 (l1, r1) 和 (l2, r2)

那么 [l, r] 这个区间的直径长度为

dis(l1, r1) dis(l1, l1)  dis(l1, r2)

dis(r1, l2) dis(r1, r2) dis(l2, r2)

六个值中的最大值

本题因为操作子树,所以我们维护dfs序的区间最长链即可

证明:

首先有一个结论:

树上任意一个点在树中的最远点是树的直径的某个端点。我们可以用反证法轻易地证明这一点。

再扩展一下,有以下结论:树上任意一个点在树中的一个点集中的最远点是该点集中最长链的一个端点。

其实我们把点集等价地看为一棵虚树,然后就能用相似的证法解决了。  

代码:

  1 #include <stdio.h>
  2 #include <algorithm>
  3 
  4 using namespace std;
  5 
  6 const int N = 2e5 + 5;
  7 
  8 int T, n, m;
  9 
 10 int len, head[N], ST[20][N];
 11 
 12 struct edge{int u, v, w;}ee[N];
 13 
 14 int cnt, fa[N], log_2[N], st[N], en[N], dfn[N], dis[N], dep[N], pos[N];
 15 
 16 struct edges{int to, next, cost;}e[N];
 17 
 18 inline void add(int u, int v, int w) {
 19     e[++ len] = (edges){v, head[u], w}, head[u] = len;
 20     e[++ len] = (edges){u, head[v], w}, head[v] = len;
 21 }
 22 
 23 inline void dfs1(int u) {
 24     st[u] = ++ cnt, dfn[cnt] = u;
 25     for (int v, i = head[u]; i; i = e[i].next) {
 26         v = e[i].to;
 27         if (v == fa[u]) continue;
 28         fa[v] = u, dep[v] = dep[u] + 1;
 29         dis[v] = dis[u] + e[i].cost, dfs1(v);
 30     }
 31     en[u] = cnt;
 32 }
 33 
 34 inline void dfs2(int u) {
 35     dfn[++ cnt] = u, pos[u] = cnt;
 36     for (int v, i = head[u]; i; i = e[i].next) {
 37         v = e[i].to;
 38         if (v == fa[u]) continue;
 39         dfs2(v), dfn[++ cnt] = u;
 40     }
 41 }
 42 
 43 int mmin(int x, int y) {
 44     if (dep[x] < dep[y]) return x;
 45     return y;
 46 }
 47 
 48 inline int lca(int u, int v) {
 49     static int w;
 50     if (pos[u] > pos[v]) swap(u, v);
 51     w = log_2[pos[v] - pos[u] + 1];
 52     return mmin(ST[w][pos[u]], ST[w][pos[v] - (1 << w) + 1]);
 53 }
 54 
 55 inline int dist(int u, int v) {
 56     int Lca = lca(u, v);
 57     return dis[u] + dis[v] - dis[Lca] * 2;
 58 }
 59 
 60 inline void build() {
 61     for (int i = 1; i <= cnt; i ++)
 62         ST[0][i] = dfn[i];
 63     for (int i = 1; i < 20; i ++)
 64         for (int j = 1; j <= cnt; j ++)
 65             if (j + (1 << (i - 1)) > cnt) ST[i][j] = ST[i - 1][j];
 66             else ST[i][j] = mmin(ST[i - 1][j], ST[i - 1][j + (1 << (i - 1))]); 
 67 }
 68 
 69 int M;
 70 
 71 struct node {
 72     int l, r, dis;
 73 }tr[N << 1];
 74 
 75 inline void update(int o, int o1, int o2) {
 76     static int d;
 77     static node tmp;
 78     if (tr[o1].dis == -1) {tr[o] = tr[o2]; return;}
 79     if (tr[o2].dis == -1) {tr[o] = tr[o1]; return;}
 80     if (tr[o1].dis > tr[o2].dis) tmp = tr[o1];
 81     else tmp = tr[o2]; 
 82     d = dist(tr[o1].l, tr[o2].l);
 83     if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].l, tmp.dis = d;
 84     d = dist(tr[o1].l, tr[o2].r);
 85     if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].r, tmp.dis = d;
 86     d = dist(tr[o1].r, tr[o2].l);
 87     if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].l, tmp.dis = d;
 88     d = dist(tr[o1].r, tr[o2].r);
 89     if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].r, tmp.dis = d;
 90     tr[o] = tmp;
 91 }
 92 
 93 inline void ask(int s, int t) {
 94     if (s > t) return;
 95     for (s += M - 1, t += M + 1; s ^ t ^ 1; s >>= 1, t >>= 1) {
 96         if (~s&1) update(0, 0, s ^ 1);
 97         if ( t&1) update(0, 0, t ^ 1);
 98     }
 99 }
100 
101 inline int get_char() {
102     static const int SIZE = 1 << 23;
103     static char *T, *S = T, buf[SIZE];
104     if (S == T) {
105         T = fread(buf, 1, SIZE, stdin) + (S = buf);
106         if (S == T) return -1;
107     }
108     return *S ++;
109 }
110  
111 inline void in(int &x) {
112     static int ch;
113     while (ch = get_char(), ch > 57 || ch < 48);x = ch - 48;
114     while (ch = get_char(), ch > 47 && ch < 58) x = x * 10 + ch - 48;
115 }
116 
117 int main() {
118     int u, v, w, ans;
119     log_2[1] = 0;
120     for (int i = 2; i <= 200000; i ++) 
121         if (i == 1 << (log_2[i - 1] + 1))
122             log_2[i] = log_2[i - 1] + 1;
123         else log_2[i] = log_2[i - 1];
124     for (in(T); T --; ) {
125         in(n), in(m), cnt = len = 0;
126         for (int i = 1; i <= n; i ++)
127             head[i] = 0;
128         for (int i = 1; i < n; i ++) {
129             in(ee[i].u), in(ee[i].v), in(ee[i].w);
130             add(ee[i].u, ee[i].v, ee[i].w);
131         }
132         dfs1(1);
133         for (M = 1; M < n + 2; M <<= 1);
134         for (int i = 1; i <= n; i ++)
135             tr[i + M].l = tr[i + M].r = dfn[i], tr[i + M].dis = 0;
136         for (int i = n + M + 1; i <= (M << 1) + 1; i ++)
137             tr[i].dis = -1;
138         cnt = 0, dfs2(1), build();
139         for (int i = M; i; i --) 
140             update(i, i << 1, i << 1 | 1);
141         for (int i = 1; i < n; i ++) 
142             if (dep[ee[i].u] > dep[ee[i].v])
143                 swap(ee[i].u, ee[i].v);
144         for (int u, v, i = 1; i <= m; i ++) {
145             in(u), in(v), ans = 0;
146             u = ee[u].v, v = ee[v].v, w = lca(u, v);
147             if (w == u || w == v) {
148                 if (w != u) swap(u, v);
149                 tr[0].dis = -1, ask(1, st[u] - 1), ask(en[u] + 1, n), ans = max(ans, tr[0].dis);
150                 tr[0].dis = -1, ask(st[u], st[v] - 1), ask(en[v] + 1, en[u]), ans = max(ans, tr[0].dis);
151                 tr[0].dis = -1, ask(st[v], en[v]), ans = max(ans, tr[0].dis);
152             }
153             else {
154                 if (st[u] > st[v]) swap(u, v);
155                 tr[0].dis = -1, ask(1, st[u] - 1), ask(en[u] + 1, st[v] - 1), ask(en[v] + 1, n), ans = max(ans, tr[0].dis);
156                 tr[0].dis = -1, ask(st[u], en[u]), ans = max(ans, tr[0].dis);
157                 tr[0].dis = -1, ask(st[v], en[v]), ans = max(ans, tr[0].dis);
158             }
159             printf("%d
", ans);
160         }
161     }
162     return 0;
163 }
View Code

一开始没带脑子算错了复杂度,少算了个log开心的写了树剖LCA,还在dfs的时候求siz忘记把儿子的siz加上了

T到死...发现是带2个log,该死出题人多组数据不给数据组数,改写ST表O(1)求LCA,复杂度只带1个log过了

理论上线段树也可以用ST表代替,复杂度O(n)...当然不可能啦,预处理nlogn,回答O(1)

附加训练 51nod 1766

  1 #include <stdio.h>
  2 #include <algorithm>
  3 
  4 using namespace std;
  5 
  6 const int N = 1e5 + 5;
  7 
  8 int n, m, M, tot, head[N], st[18][N << 1], log_2[N << 1];
  9 
 10 int cnt, dis[N], dep[N], pos[N], dfn[N << 1];
 11 
 12 struct edge{int to, next, cost;}e[N << 1];
 13 
 14 int mmin(int x, int y) {
 15     return dep[x] < dep[y] ? x : y;
 16 }
 17 
 18 void add(int u, int v, int w) {
 19     e[++ tot] = (edge){v, head[u], w}, head[u] = tot;
 20     e[++ tot] = (edge){u, head[v], w}, head[v] = tot;
 21 }
 22 
 23 void dfs(int u, int fr) {
 24     dfn[++ cnt] = u, pos[u] = cnt;
 25     for (int v, i = head[u]; i; i = e[i].next) {
 26         v = e[i].to;
 27         if (v == fr) continue;
 28         dep[v] = dep[u] + 1, dis[v] = dis[u] + e[i].cost;
 29         dfs(v, u), dfn[++ cnt] = u;
 30     }
 31 }
 32 
 33 int lca(int u, int v) {
 34     if (pos[u] > pos[v]) swap(u, v);
 35     int w = log_2[pos[v] - pos[u] + 1];
 36     return mmin(st[w][pos[u]], st[w][pos[v] - (1 << w) + 1]);
 37 }
 38 
 39 int dist(int u, int v) {
 40     return dis[u] + dis[v] - dis[lca(u, v)] * 2;
 41 }
 42 
 43 struct node {
 44     int l, r, dis;
 45 
 46     node operator + (const node &a) const {
 47         node res;
 48         if (dis == -1) return a;
 49         if (a.dis == -1) return *this;
 50         if (dis > a.dis) res = *this;
 51         else res = a;
 52         int d = dist(l, a.l);
 53         if (d > res.dis) res.l = l, res.r = a.l, res.dis = d;
 54         d = dist(l, a.r);
 55         if (d > res.dis) res.l = l, res.r = a.r, res.dis = d;
 56         d = dist(r, a.l);
 57         if (d > res.dis) res.l = r, res.r = a.l, res.dis = d;
 58         d = dist(r, a.r);
 59         if (d > res.dis) res.l = r, res.r = a.r, res.dis = d;
 60         return res;
 61     }
 62 
 63     node operator * (const node &a) const {
 64         node res; res.dis = -1;
 65         int d = dist(l, a.l);
 66         if (d > res.dis) res.l = l, res.r = a.l, res.dis = d;
 67         d = dist(l, a.r);
 68         if (d > res.dis) res.l = l, res.r = a.r, res.dis = d;
 69         d = dist(r, a.l);
 70         if (d > res.dis) res.l = r, res.r = a.l, res.dis = d;
 71         d = dist(r, a.r);
 72         if (d > res.dis) res.l = r, res.r = a.r, res.dis = d;
 73         return res;
 74     }
 75 }tr[N << 2];
 76 
 77 node ask(int s, int t) {
 78     node res; res.dis = -1;
 79     for (s += M - 1, t += M + 1; s ^ t ^ 1; s >>= 1, t >>= 1) {
 80         if (~s&1) res = res + tr[s ^ 1];
 81         if ( t&1) res = res + tr[t ^ 1];
 82     }
 83     return res;
 84 }
 85 
 86 int main() {
 87     scanf("%d", &n);
 88     for (int u, v, w, i = 1; i < n; i ++)
 89         scanf("%d %d %d", &u, &v, &w), add(u, v, w);
 90     dfs(1, 1);
 91 
 92     for (int i = 1; i <= cnt; i ++)
 93         st[0][i] = dfn[i];
 94     for (int i = 1; i < 18; i ++)
 95         for (int j = 1; j <= cnt; j ++)
 96             if (j + (1 << (i - 1)) > cnt) st[i][j] = st[i - 1][j];
 97             else st[i][j] = mmin(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]); 
 98     log_2[1] = 0;
 99     for (int i = 2; i <= cnt; i ++)
100         log_2[i] = log_2[i - 1] + (i == (1 << (log_2[i - 1] + 1)));
101 
102     for (M = 1; M < n + 2; M <<= 1);
103     for (int i = 1; i <= n; i ++) tr[i + M] = (node){i, i, 0};
104     for (int i = n + 1; i <= M + 1; i ++) tr[i + M].dis = -1;
105     for (int i = M; i; i --) tr[i] = tr[i << 1] + tr[i << 1 | 1]; 
106 
107     node tmp; int a, b, c, d;
108     for (scanf("%d", &m); m --; ) {
109         scanf("%d %d %d %d", &a, &b, &c, &d);
110         tmp = ask(a, b) * ask(c, d);
111         printf("%d
", tmp.dis);
112     }
113     return 0;
114 }
View Code

相对简单一点了

原文地址:https://www.cnblogs.com/ytytzzz/p/9674661.html