【HDOJ】4358 Boring counting

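基本思路是用 DFS 序将树形结构转为线性结构:每次查询的是以某个结点为根的整棵子树,而子树中的结点在 DFS 序中恰好占据一段连续区间。
从而将每个查询转换成一个区间 [Beg[u], End[u]],问题变为:统计该区间内恰好出现 K 次的权值有多少种。
离线处理:把所有查询按照右边界升序排序,从左到右依次加入每个位置,当前位置等于某查询的右边界时回答该查询。
利用树状数组实现"区间修改、单点查询"(在差分数组上求前缀和)。
树状数组中位置 l 的值表示:以 l 为左端点、当前位置 i 为右端点的区间内恰好出现 K 次的权值个数。每个权值出现过的位置按顺序记录在数组中(代码中的 pvc,下标 0 处放一个哨兵 0),记为 p[1..sz]。
若某权值当前已出现 sz 次且 sz >= K,则左端点落在 (p[sz-K], p[sz-K+1]] 内时该权值恰好出现 K 次,对这段区间 +1;若 sz > K,还要撤销上一步加在 (p[sz-K-1], p[sz-K]] 上的 +1。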

  1 /* 4358 */
  2 #include <iostream>
  3 #include <sstream>
  4 #include <string>
  5 #include <map>
  6 #include <queue>
  7 #include <set>
  8 #include <stack>
  9 #include <vector>
 10 #include <deque>
 11 #include <algorithm>
 12 #include <cstdio>
 13 #include <cmath>
 14 #include <ctime>
 15 #include <cstring>
 16 #include <climits>
 17 #include <cctype>
 18 #include <cassert>
 19 #include <functional>
 20 #include <iterator>
 21 #include <iomanip>
 22 using namespace std;
 23 //#pragma comment(linker,"/STACK:102400000,1024000")
 24 
 25 #define sti                set<int>
 26 #define stpii            set<pair<int, int> >
 27 #define mpii            map<int,int>
 28 #define vi                vector<int>
 29 #define pii                pair<int,int>
 30 #define vpii            vector<pair<int,int> >
 31 #define rep(i, a, n)     for (int i=a;i<n;++i)
 32 #define per(i, a, n)     for (int i=n-1;i>=a;--i)
 33 #define clr                clear
 34 #define pb                 push_back
 35 #define mp                 make_pair
 36 #define fir                first
 37 #define sec                second
 38 #define all(x)             (x).begin(),(x).end()
 39 #define SZ(x)             ((int)(x).size())
 40 #define lson            l, mid, rt<<1
 41 #define rson            mid+1, r, rt<<1|1
 42 
 43 typedef struct {
 44     int v, nxt;
 45 } edge_t;
 46 
 47 typedef struct node_t {
 48     int w, id; 
 49     
 50     friend bool operator< (const node_t& a, const node_t& b) {
 51         if (a.w == b.w)
 52             return a.id < b.id;
 53         return a.w < b.w;
 54     }
 55     
 56 } node_t;
 57 
 58 typedef struct ques_t {
 59     int l, r, id;
 60 } ques_t;
 61 
 62 const int maxn = 1e5+5;
 63 const int maxv = maxn;
 64 const int maxe = maxv * 2;
 65 int head[maxv], l;
 66 edge_t E[maxe];
 67 int Beg[maxn], End[maxn];
 68 int val[maxn], W[maxn];
 69 node_t nd[maxn];
 70 vi pvc[maxn];
 71 int dfs_clock;
 72 int n, K;
 73 int a[maxn];
 74 ques_t Q[maxn];
 75 int ans[maxn];
 76     
 77 bool compq (const ques_t& a, const ques_t& b) {
 78     if (a.r == b.r)
 79         return a.l < b.l;
 80     return a.r < b.r;
 81 }
 82 
 83 void init() {
 84     memset(head, -1, sizeof(head));
 85     memset(a, 0, sizeof(a));
 86     dfs_clock = l = 0;
 87 }
 88 
 89 void addEdge(int u, int v) {
 90     E[l].v = v;
 91     E[l].nxt = head[u];
 92     head[u] = l++;
 93     
 94     E[l].v = u;
 95     E[l].nxt = head[v];
 96     head[v] = l++;
 97 }
 98 
 99 void dfs(int u, int fa) {
100     int v, k;
101     
102     Beg[u] = ++dfs_clock;
103     val[dfs_clock] = W[u];
104     for (k=head[u]; k!=-1; k=E[k].nxt) {
105         v = E[k].v;
106         if (v == fa)
107             continue;
108         dfs(v, u);
109     }
110     End[u] = dfs_clock;
111 }
112 
// Value of the lowest set bit of x (e.g. 12 -> 4); 0 when x == 0.
// Clearing that bit with x & (x - 1) and XOR-ing it back isolates it.
int lowest(int x) {
    return x ^ (x & (x - 1));
}
116 
117 int sum(int x) {
118     int ret = 0;
119     
120     while (x) {
121         ret += a[x];
122         x -= lowest(x);
123     }
124     
125     return ret;
126 }
127 
128 void update(int x, int delta) {
129     while (x <= n) {
130         a[x] += delta;
131         x += lowest(x);
132     }
133 }
134 
135 void solve() {
136     int q;
137     int u, v;
138     
139     sort(nd+1, nd+1+n);
140     int cnt = 1;
141     W[nd[1].id] = cnt;
142     rep(i, 2, n+1) {
143         if (nd[i].w==nd[i-1].w) {
144             W[nd[i].id] = cnt;
145         } else {
146             W[nd[i].id] = ++cnt;
147         }
148     }
149     
150     dfs(1, 0);
151     
152     rep(i, 0, cnt+1) {
153         pvc[i].clr();
154         pvc[i].pb(0);
155     }
156     
157     scanf("%d", &q);
158     rep(i, 0, q) {
159         scanf("%d", &u);
160         Q[i].l = Beg[u];
161         Q[i].r = End[u];
162         Q[i].id = i;
163     }
164     
165     sort(Q, Q+q, compq);
166     int sz;
167     int j = 0;
168     
169     rep(i, 1, n+1) {
170         pvc[val[i]].pb(i);
171         sz = SZ(pvc[val[i]]) - 1;
172         if (sz >= K) {
173             if (sz > K) {
174                 update(pvc[val[i]][sz-K-1]+1, -1);
175                 update(pvc[val[i]][sz-K]+1, 1);
176             }
177             update(pvc[val[i]][sz-K]+1, 1);
178             update(pvc[val[i]][sz-K+1]+1, -1);
179         }
180         while (j<q && Q[j].r==i) {
181             ans[Q[j].id] = sum(Q[j].l);
182             ++j;
183         }
184     }
185     
186     rep(i, 0, q)
187         printf("%d
", ans[i]);
188 }
189 
190 int main() {
191     ios::sync_with_stdio(false);
192     #ifndef ONLINE_JUDGE
193         freopen("data.in", "r", stdin);
194         freopen("data.out", "w", stdout);
195     #endif
196     
197     int t;
198     int u, v;
199     
200     scanf("%d", &t);
201     rep(tt, 1, t+1) {
202         init();
203         scanf("%d %d", &n, &K);
204         rep(i, 1, n+1) {
205             scanf("%d", &W[i]);
206             nd[i].id = i;
207             nd[i].w = W[i];
208         }
209         rep(i, 1, n) {
210             scanf("%d %d", &u, &v);
211             addEdge(u, v);
212         }
213         printf("Case #%d:
", tt);
214         solve();
215         if (tt != t)
216             putchar('
');
217     }
218     
219     #ifndef ONLINE_JUDGE
220         printf("time = %d.
", (int)clock());
221     #endif
222     
223     return 0;
224 }

数据生成器(随机生成测试输入 data.in)。

 1 from copy import deepcopy
 2 from random import randint, shuffle
 3 import shutil
 4 import string
 5 
 6 
 7 def GenDataIn():
 8     with open("data.in", "w") as fout:
 9         t = 10
10         bound = 10**9
11         fout.write("%d
" % (t))
12         for tt in xrange(t):
13             n = randint(100, 200)
14             K = randint(1, 5)
15             fout.write("%d %d
" % (n, K))
16             ust = [1]
17             vst = range(2, n+1)
18             L = []
19             for i in xrange(n):
20                 x = randint(1, 100)
21                 L.append(x)
22             fout.write(" ".join(map(str, L)) + "
")
23             for i in xrange(1, n):
24                 idx = randint(0, len(ust)-1)
25                 u = ust[idx]
26                 idx = randint(0, len(vst)-1)
27                 v = vst[idx]
28                 ust.append(v)
29                 vst.remove(v)
30                 fout.write("%d %d
" % (u, v))
31             q = n
32             fout.write("%d
" % (q))
33             L = range(1, n+1)
34             shuffle(L)
35             fout.write("
".join(map(str, L)) + "
")
36                 
37                 
def MovDataIn(desFileName=r"F:\eclipse_prj\workspace\hdoj\data.in"):
    """Copy ``data.in`` from the current directory to ``desFileName``.

    The default is the author's local Windows workspace path. It is a raw
    string so the backslash separators are kept literally. Pass another path
    when running elsewhere.
    """
    shutil.copyfile("data.in", desFileName)
41 
42     
if __name__ == "__main__":
    # Regenerate the random input file, then copy it into the workspace.
    for step in (GenDataIn, MovDataIn):
        step()
原文地址:https://www.cnblogs.com/bombe1013/p/5191707.html