brz的树【牛客练习赛72 F】【虚树+dfs序可持久化线段树+树上差分】

题目链接

  有N个点的树,现在我们有M次询问,询问(u、v),只在u、v子树内出现的颜色的个数(u子树 并 v子树)。

  首先,可以将问题拆开来讨论,如果对于一个点的时候是怎样的呢?

  先求一个点时候的答案:

  那么,实际上就是求“颜色总数-不在子树内出现的颜色总数”。那么,实际上,我们可以将点数扩展成2N这么多个,然后查询出现在dfn[u] + siz[u] ~ dfn[u] + N - 1内的颜色的个数。当然,不嫌麻烦的话,就不需要这么做,直接写树剖,然后找到每个颜色的总的父亲节点,给1~point这条链全部“+1”,然后之后单点查询即可。——于是,我们就可以求得每个点的自身子树所产生的答案了。

  然后考虑两个点的时候,当且仅当我们对一个颜色构建虚树,虚树有虚根(虚根颜色不为该颜色),且虚根有且仅有两个虚子节点的时候,才需要考虑这种颜色的贡献。

  于是,求两个点时候的答案:

  我们不妨对虚子节点进行配对,譬如说“u指向v”这样的,于是,对于查询的x、y,我们实际上就是想知道x的子树内,有多少是指向y子树内的。那么,对于x的子树区间为dfn[x] ~ dfn[x] + siz[x] - 1,同理,y的子树。那么,比较显然的,这不就是用dfs序来构造一个主席树,然后在主席树上进行差分,找到x子树对于y子树的贡献吗?

  所以,我们可以利用虚树来找到这样的u、v对,然后我们让dfs序小的指向dfs序大的,最后,我们对它的dfs序上建立可持久化线段树,然后对于查询的x、y对,我们用dfs序小的作为根的差分,去查dfs序大的部分。——于是,利用这些,我们就可以知道这样的关系对对每个答案产生的贡献了。

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cmath>
  4 #include <string>
  5 #include <cstring>
  6 #include <algorithm>
  7 #include <limits>
  8 #include <vector>
  9 #include <stack>
 10 #include <queue>
 11 #include <set>
 12 #include <map>
 13 #include <bitset>
 14 #include <unordered_map>
 15 #include <unordered_set>
 16 #define lowbit(x) ( x&(-x) )
 17 #define pi 3.141592653589793
 18 #define e 2.718281828459045
 19 #define INF 0x3f3f3f3f
 20 #define HalF (l + r)>>1
 21 #define lsn rt<<1
 22 #define rsn rt<<1|1
 23 #define Lson lsn, l, mid
 24 #define Rson rsn, mid+1, r
 25 #define QL Lson, ql, qr
 26 #define QR Rson, ql, qr
 27 #define myself rt, l, r
 28 #define pii pair<int, int>
 29 #define MP(a, b) make_pair(a, b)
 30 using namespace std;
 31 typedef unsigned long long ull;
 32 typedef unsigned int uit;
 33 typedef long long ll;
 34 const int maxN = 1e5 + 7;
 35 int N, M, head[maxN], cnt, col[maxN], lsan[maxN], _UP;
 36 vector<int> vt[maxN];
 37 struct Eddge
 38 {
 39     int nex, to;
 40     Eddge(int a=-1, int b=0):nex(a), to(b) {}
 41 } edge[maxN << 1];
 42 inline void addEddge(int u, int v)
 43 {
 44     edge[cnt] = Eddge(head[u], v);
 45     head[u] = cnt ++;
 46 }
 47 inline void _add(int u, int v) { addEddge(u, v); addEddge(v, u); }
 48 inline void init()
 49 {
 50     cnt = 0;
 51     for(int i = 1; i <= N; i ++) head[i] = -1;
 52 }
 53 int deep[maxN], fa[maxN][20], LOG_2[maxN], siz[maxN], dfn[maxN], idx, rid[maxN << 1];
 54 bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
 55 void dfs(int u, int father)
 56 {
 57     siz[u] = 1;
 58     fa[u][0] = father;
 59     deep[u] = deep[father] + 1;
 60     dfn[u] = ++idx;
 61     rid[idx] = u;
 62     for(int i = 0; i < 18; i ++) fa[u][i + 1] = fa[fa[u][i]][i];
 63     for(int i = head[u], v; ~i; i = edge[i].nex)
 64     {
 65         v = edge[i].to;
 66         if(v == father) continue;
 67         dfs(v, u);
 68         siz[u] += siz[v];
 69     }
 70 }
 71 int lca(int u, int v)
 72 {
 73     if(deep[u] < deep[v]) swap(u, v);
 74     int det = deep[u] - deep[v];
 75     for(int i = LOG_2[det]; i >= 0; i --)
 76     {
 77         if((det >> i) & 1) u = fa[u][i];
 78     }
 79     if(u == v) return u;
 80     for(int i = LOG_2[deep[v]]; i >= 0; i --)
 81     {
 82         if(fa[u][i] ^ fa[v][i])
 83         {
 84             u = fa[u][i];
 85             v = fa[v][i];
 86         }
 87     }
 88     return fa[u][0];
 89 }
 90 int stk[maxN], top;
 91 vector<int> son;
 92 void Insert(int u)
 93 {
 94     if(top <= 1)
 95     {
 96         stk[++ top] = u;
 97         return;
 98     }
 99     int p = lca(u, stk[top]);
100     while(top >= 2 && dfn[p] <= dfn[stk[top - 1]])
101     {
102         if(top == 2) son.push_back(stk[top]);
103         top --;
104     }
105     if(stk[top] ^ p)
106     {
107         stk[top] = p;
108     }
109     stk[++ top] = u;
110 }
111 vector<int> nex[maxN];
112 void solve(int op)
113 {
114     sort(vt[op].begin(), vt[op].end(), cmp);
115     int root = vt[op][0];
116     for(int u : vt[op]) root = lca(root, u);
117     top = 0; son.clear();
118     Insert(root);
119     for(int u : vt[op])
120     {
121         if(u == root)
122         {
123             top = 0;
124             son.clear();
125             break;
126         }
127         Insert(u);
128     }
129     if(top >= 2) son.push_back(stk[2]);
130     if(son.size() == 2)
131     {
132         int x = son[0], y = son[1];
133         if(dfn[x] > dfn[y]) swap(x, y);
134         nex[dfn[x]].push_back(dfn[y]);
135     }
136 }
137 int t[maxN << 1];
138 void add(int x, int v) { while(x <= (N << 1)) { t[x] += v; x += lowbit(x); } }
139 int sum(int x) { int res = 0; while(x) { res += t[x]; x -= lowbit(x); } return res; }
140 vector<pii> ques[maxN << 1];
141 int las_col[maxN] = {0};
142 int ans[maxN];
143 namespace Segement
144 {
145     const int maxP = maxN * 30;
146     int tree[maxP], lc[maxP], rc[maxP];
147     int root[maxN], tot;
148     void build(int &rt, int old, int l, int r, int qx)
149     {
150         rt = ++ tot;
151         lc[rt] = lc[old]; rc[rt] = rc[old]; tree[rt] = tree[old] + 1;
152         if(l == r) return;
153         int mid = HalF;
154         if(qx <= mid) build(lc[rt], lc[old], l, mid, qx);
155         else build(rc[rt], rc[old], mid + 1, r, qx);
156     }
157     int query(int rl, int rr, int l, int r, int ql, int qr)
158     {
159         if(ql <= l && qr >= r) return tree[rr] - tree[rl];
160         int mid = HalF;
161         if(qr <= mid) return query(lc[rl], lc[rr], l, mid, ql, qr);
162         else if(ql > mid) return query(rc[rl], rc[rr], mid + 1, r, ql, qr);
163         else return query(lc[rl], lc[rr], l, mid, ql, qr) + query(rc[rl], rc[rr], mid + 1, r, ql, qr);
164     }
165 }
166 using namespace Segement;
167 struct Question
168 {
169     int u, v, id;
170     Question(int a=0, int b=0, int c=0):u(a), v(b), id(c) {}
171 };
172 vector<Question> qt;
173 int main()
174 {
175     for(int i = 2; i < maxN; i ++) LOG_2[i] = LOG_2[i >> 1] + 1;
176     scanf("%d%d", &N, &M);
177     init();
178     for(int i = 1; i <= N; i ++) { scanf("%d", &col[i]); lsan[i] = col[i]; }
179     sort(lsan + 1, lsan + N + 1);
180     _UP = (int)(unique(lsan + 1, lsan + N + 1) - lsan - 1);
181     for(int i = 1; i <= N; i ++)
182     {
183         col[i] = (int)(lower_bound(lsan + 1, lsan + _UP + 1, col[i]) - lsan);
184         vt[col[i]].push_back(i);
185     }
186     for(int i = 1, u, v; i < N; i ++)
187     {
188         scanf("%d%d", &u, &v);
189         _add(u, v);
190     }
191     dfs(1, 0);
192     for(int i = 1; i <= N; i ++) rid[N + i] = rid[i];
193     for(int i = 1; i <= _UP; i ++) solve(i);
194     for(int i = 1, x, y, p, l, r; i <= M; i ++)
195     {
196         scanf("%d%d", &x, &y);
197         if(dfn[x] > dfn[y]) swap(x, y);
198         p = lca(x, y);
199         if(p == x)
200         {
201             ans[i] = _UP;
202             l = dfn[p] + siz[p]; r = N + dfn[p] - 1;
203             ques[r].push_back(MP(l, i));
204         }
205         else
206         {
207             ans[i] = _UP << 1;
208             l = dfn[x] + siz[x]; r = N + dfn[x] - 1;
209             ques[r].push_back(MP(l, i));
210             l = dfn[y] + siz[y]; r = N + dfn[y] - 1;
211             ques[r].push_back(MP(l, i));
212             qt.push_back(Question(x, y, i));
213         }
214     }
215     for(int i = 1, u, c, id, lx; i <= (N << 1); i ++)
216     {
217         u = rid[i];
218         c = col[u];
219         add(i, 1);
220         if(las_col[c]) add(las_col[c], -1);
221         las_col[c] = i;
222         for(pii it : ques[i])
223         {
224             id = it.second;
225             lx = it.first;
226             ans[id] -= sum(i) - sum(lx - 1);
227         }
228     }
229     for(int i = 1; i <= N; i ++)
230     {
231         root[i] = root[i - 1];
232         for(int j : nex[i])
233         {
234             build(root[i], root[i], 1, N, j);
235         }
236     }
237     for(Question it : qt)
238     {
239         int u = it.u, v = it.v, id = it.id;
240         int l = dfn[u], r = dfn[u] + siz[u] - 1, ql = dfn[v], qr = dfn[v] + siz[v] - 1;
241         ans[id] += query(root[l - 1], root[r], 1, N, ql, qr);
242     }
243     for(int i = 1; i <= M; i ++) printf("%d
", ans[i]);
244     return 0;
245 }
原文地址:https://www.cnblogs.com/WuliWuliiii/p/14201498.html