ACM之路(19)—— 主席树初探

  长春赛的 I 题是主席树,现在稍微的学了一点主席树,也就算入了个门吧= =

  简单的来说主席树就是每个节点上面都是一棵线段树,但是这么多线段树会MLE吧?其实我们解决的办法就是有重复的节点给他利用起来,具体见幻神博客

  不妨以1~n上的求任意区间第k小的问题,就是上面博客中所写,我们从1访问到n的预处理中,每一个时间都新建一个线段树,这棵树上记录着已经出现的各个数字,这样我们求[L,R]上的第k小,我们拿R时刻的线段树减去(L-1)时刻的线段树,就是这个区间内需要的线段树,这个线段树上存在的数字其实就是[L,R]上存在的数字,我们在这里寻找我们需要的第k小就可以了。具体实现方法见上面的博客。

  我自己的模板如下:

 1 #include <stdio.h>
 2 #include <algorithm>
 3 #include <string.h>
 4 #define t_mid (l+r>>1)
 5 using namespace std;
 6 const int N = 100000 + 5;
 7 
 8 int n,q,tot,sz;
 9 int a[N],b[N];
10 int rt[N*20],sum[N*20],ls[N*20],rs[N*20];
11 void build(int &o,int l,int r)
12 {
13     o = ++tot;
14     sum[o] = 0;
15     if(l==r) return;
16     build(ls[o],l,t_mid);
17     build(rs[o],t_mid+1,r);
18 }
19 
20 void update(int &o,int l,int r,int last,int p)
21 {
22     o = ++tot;
23     ls[o] = ls[last];
24     rs[o] = rs[last];
25     sum[o] = sum[last] + 1;
26     if(l==r) return;
27     if(p <= t_mid) update(ls[o],l,t_mid,ls[last],p);
28     else update(rs[o],t_mid+1,r,rs[last],p);
29 }
30 
31 int query(int ql,int qr,int l,int r,int k)
32 {
33     if(l==r) return l;
34     int cnt = sum[ls[qr]] - sum[ls[ql]];
35     if(cnt >= k) return query(ls[ql],ls[qr],l,t_mid,k);
36     else return query(rs[ql],rs[qr],t_mid+1,r,k-cnt);
37 }
38 
39 void work()
40 {
41     int ql,qr,k;
42     scanf("%d%d%d",&ql,&qr,&k);
43     int ans = query(rt[ql-1],rt[qr],1,sz,k);
44     printf("%d
",b[ans]);
45 }
46 
47 int main()
48 {
49     while(scanf("%d%d",&n,&q)==2)
50     {
51         tot = 0;
52         for(int i=1;i<=n;i++) scanf("%d",a+i),b[i]=a[i];
53         sort(b+1,b+1+n);
54         sz = unique(b+1,b+1+n) - (b+1);
55         build(rt[0],1,sz);
56         
57         for(int i=1;i<=n;i++)
58         {
59             int t = lower_bound(b+1,b+1+sz,a[i]) - b;
60             update(rt[i],1,sz,rt[i-1],t);
61         }
62         while(q--) work();
63     }
64 }
求区间第K小

  然后如果是在一棵树上,求其一条链上的区间第k小呢?其实也差不多,我们就想着怎么把这棵需要的线段树抽取出来就行。这棵树实际上就是 u - lca(u,v) + v - father(lca(u,v))。具体的画画图就可以懂了。这里还涉及到求LCA的方法,具体方法见《挑战程序设计》中的倍增法求LCA即可。

  我自己的模板如下:

  1 #include <stdio.h>
  2 #include <algorithm>
  3 #include <string.h>
  4 #include <vector>
  5 #include <math.h>
  6 #define t_mid (l+r>>1)
  7 using namespace std;
  8 const int N = 100000 + 5;
  9 const int MAX_LOG_N = 16 + 5;
 10 
 11 int n,q,tot,sz;
 12 int a[N],b[N];
 13 int rt[N*20],sum[N*20],ls[N*20],rs[N*20];
 14 int parent[MAX_LOG_N][N],depth[N];
 15 vector<int> G[N];
 16 
 17 void getDepth(int v,int p,int d)
 18 {
 19     parent[0][v] = p;
 20     depth[v] = d;
 21     for(int i=0;i<G[v].size();i++)
 22     {
 23         if(G[v][i] != p) getDepth(G[v][i],v,d+1);
 24     }
 25 }
 26 
 27 void init()
 28 {
 29     getDepth(1,-1,0);
 30     for(int k=0;k+1<MAX_LOG_N;k++)
 31     {
 32         for(int v=1;v<=n;v++)
 33         {
 34             if(parent[k][v] < 0) parent[k+1][v] = -1;
 35             else parent[k+1][v] = parent[k][parent[k][v]];
 36         }
 37     }
 38 }
 39 
 40 int lca(int u,int v)
 41 {
 42     if(depth[u]>depth[v]) swap(u,v);
 43     for(int k=0;k<MAX_LOG_N;k++)
 44     {
 45         if((depth[v]-depth[u]) >> k & 1)
 46         {
 47             v = parent[k][v];
 48         }
 49     }
 50     if(u==v) return u;
 51     for(int k=MAX_LOG_N-1;k>=0;k--)
 52     {
 53         if(parent[k][u] != parent[k][v])
 54         {
 55             u = parent[k][u];
 56             v = parent[k][v];
 57         }
 58     }
 59     return parent[0][u];
 60 }
 61 
 62 void build(int &o,int l,int r)
 63 {
 64     o = ++tot;
 65     sum[o] = 0;
 66     if(l==r) return;
 67     build(ls[o],l,t_mid);
 68     build(rs[o],t_mid+1,r);
 69 }
 70 
 71 void update(int &o,int l,int r,int last,int p)
 72 {
 73     o = ++tot;
 74     ls[o] = ls[last];
 75     rs[o] = rs[last];
 76     sum[o] = sum[last] + 1;
 77     if(l==r) return;
 78     if(p <= t_mid) update(ls[o],l,t_mid,ls[last],p);
 79     else update(rs[o],t_mid+1,r,rs[last],p);
 80 }
 81 
 82 int query(int u,int v,int x,int y,int l,int r,int k)
 83 {
 84     if(l==r) return l;
 85     int cnt = sum[ls[u]] + sum[ls[v]] - sum[ls[x]] - sum[ls[y]];
 86     if(cnt >= k) return query(ls[u],ls[v],ls[x],ls[y],l,t_mid,k);
 87     else return query(rs[u],rs[v],rs[x],rs[y],t_mid+1,r,k-cnt);
 88 }
 89 
 90 void work()
 91 {
 92     int u,v,k;
 93     scanf("%d%d%d",&u,&v,&k);
 94     int _lca = lca(u,v);
 95     int _lca_fa = parent[0][_lca];
 96     int ans = query(rt[u],rt[v],rt[_lca],rt[_lca_fa],1,sz,k);
 97     printf("%d
",b[ans]);
 98 }
 99 
100 void dfs(int u,int fa)
101 {
102     for(int i=0;i<G[u].size();i++)
103     {
104         int v = G[u][i];
105         if(v==fa) continue;
106         int t = lower_bound(b+1,b+1+sz,a[v]) - b;
107         update(rt[v],1,sz,rt[u],t);
108         dfs(v,u);
109     }
110 }
111 
112 int main()
113 {
114     while(scanf("%d%d",&n,&q)==2)
115     {
116         tot = 0;
117         for(int i=1;i<=n;i++) G[i].clear();
118         for(int i=1;i<=n;i++) scanf("%d",a+i),b[i]=a[i];
119         sort(b+1,b+1+n);
120         sz = unique(b+1,b+1+n) - (b+1);
121         for(int i=1;i<n;i++)
122         {
123             int u,v;scanf("%d%d",&u,&v);
124             G[u].push_back(v);
125             G[v].push_back(u);
126         }
127         build(rt[0],1,sz);
128         init();
129         
130         int t = lower_bound(b+1,b+1+sz,a[1]) - b;
131         update(rt[1],1,sz,rt[0],t);
132         dfs(1,-1);
133         
134         while(q--) work();
135     }
136 }
求树上的一条链的第K小

  好,接下来就是解决那个烦人的 I 题了。

  我们首先需要用主席树来解决区间内不同的数的个数,这东西比较奥义- -直接上模板好了。。反正随便百度一下"主席树求区间内不同数的个数"都会出来spoj的D-query那题,随便看下原理就行= =。。。然后用二分解决 I 题(固定左端点,二分右端点,具体见代码。。)。

  看我直接丢 I 题的代码~:

  1 #include <stdio.h>
  2 #include <algorithm>
  3 #include <string.h>
  4 #include <map>
  5 #define t_mid (l+r>>1)
  6 using namespace std;
  7 const int N = 2*100000 + 50;
  8 
  9 int rt[N*20*2],sum[N*20*2],ls[N*20*2],rs[N*20*2];
 10 int a[N],n,m,tot;
 11 void build(int &o,int l,int r)
 12 {
 13     o = ++tot;
 14     sum[o] = 0;
 15     if(l == r) return;
 16     build(ls[o],l,t_mid);
 17     build(rs[o],t_mid+1,r);
 18 }
 19 
 20 void update(int &o,int l,int r,int last,int pos,int dt)
 21 {
 22     o = ++tot;
 23     sum[o] = sum[last];
 24     ls[o] = ls[last];
 25     rs[o] = rs[last];
 26     if(l==r) {sum[o]+=dt;return;}
 27     if(pos <= t_mid) update(ls[o],l,t_mid,ls[last],pos,dt);
 28     else update(rs[o],t_mid+1,r,rs[last],pos,dt);
 29     sum[o] = sum[ls[o]] + sum[rs[o]];
 30 }
 31 
 32 int query(int l,int r,int o,int pos)
 33 {
 34     if(l == r) return sum[o];
 35     if(pos <= t_mid) return sum[rs[o]] + query(l,t_mid,ls[o],pos);
 36     else return query(t_mid+1,r,rs[o],pos);
 37 }
 38 
 39 /*
 40 int query(int l,int r,int L,int R,int x){
 41     if(L <= l && r <= R) return sum[x];
 42     int mid = (l+r) >> 1 , ret = 0;
 43     if(L <= mid) ret += query(l,mid,L,R,ls[x]);
 44     if(R > mid) ret += query(mid+1,r,L,R,rs[x]);
 45     return ret;
 46 }
 47 */
 48 
 49 int main()
 50 {
 51     int T;scanf("%d",&T);
 52     for(int kase=1;kase<=T;kase++)
 53     {
 54         scanf("%d%d",&n,&m);
 55         int pre = 0;
 56         map<int,int> mp;
 57         tot = 0;
 58         for(int i=1;i<=n;i++) scanf("%d",a+i);
 59         build(rt[0],1,n);
 60 
 61         for(int i=1;i<=n;i++)
 62         {
 63             if(mp.find(a[i]) == mp.end())
 64             {
 65                 mp[a[i]] = i;
 66                 update(rt[i],1,n,rt[i-1],i,1);
 67             }
 68             else
 69             {
 70                 int temp = 0;
 71                 update(temp,1,n,rt[i-1],mp[a[i]],-1);
 72                 update(rt[i],1,n,temp,i,1);
 73             }
 74             mp[a[i]] = i;
 75         }
 76         //scanf("%d",&m);
 77         printf("Case #%d:",kase);
 78         while(m--)
 79         {
 80             int ql,qr;scanf("%d%d",&ql,&qr);
 81             int L = min((ql+pre)%n+1,(qr+pre)%n+1);
 82             int R = max((ql+pre)%n+1,(qr+pre)%n+1);
 83             //L = ql, R = qr;
 84             int k = (query(1,n,rt[R],L)+1)>>1;
 85             int l = L, r = R;
 86             //printf("!! %d %d 
",L,R);
 87             int ans = -1;
 88             while(l<=r)
 89             {
 90                 int mid = l + r >> 1;
 91                 int t = query(1,n,rt[mid],L);
 92                 //printf("mid is %d %d
",mid,t);
 93                 if(t < k) l = mid + 1;
 94                 else
 95                 {
 96                     r = mid - 1;
 97                     ans = mid;
 98                 }
 99             }
100             /*while(l < r)
101             {
102                 int mid = l + r >> 1;
103                 int t = query(1,n,rt[mid],L);
104                 if(t < k) l = mid + 1;
105                 else r = mid;
106             }*/
107             
108             printf(" %d",ans);
109             pre = ans;
110         }
111         puts("");
112     }
113 }
114 
115 /*
116 100
117 20 100
118 1 2 3 4 3 2 1 2 4 2 2 3 1 2 3 1 4 4 2 1
119 1 20
120 1 10
121 2 5
122 4 6
123 3 2
124 4 7
125 
126 100
127 5 100
128 0 1 0 2 3
129 1 5
130 */
131 /*
132 #include<iostream>
133 //#include<bits/stdc++.h>
134 #include<cstdio>
135 #include<string>
136 #include<cstring>
137 #include<map>
138 #include<queue>
139 #include<set>
140 #include<stack>
141 #include<ctime>
142 #include<algorithm>
143 #include<cmath>
144 #include<vector>
145 #define showtime fprintf(stderr,"time = %.15f
",clock() / (double)CLOCKS_PER_SEC)
146 //#pragma comment(linker, "/STACK:1024000000,1024000000")
147 using namespace std;
148 typedef long long ll;
149 typedef long long LL;
150 #define MP make_pair
151 #define PII pair<int,int>
152 #define PLI pair<long long ,int>
153 #define PFI pair<double,int>
154 #define PLL pair<ll,ll>
155 #define PB push_back
156 #define F first
157 #define S second
158 #define lson l,mid,rt<<1
159 #define rson mid+1,r,rt<<1|1
160 #define debug cout<<"?????"<<endl;
161 //freopen("1005.in","r",stdin);
162 //freopen("data.out","w",stdout);
163 const int INF = 0x3f3f3f3f;
164 const double eps = 1e-2;
165 const int N = 4e5 + 50 ;
166 const double PI = acos(-1.);
167 const double E = 2.71828182845904523536;
168 const int MOD = 1e9+7;
169 typedef vector<ll> Vec;
170 typedef vector<Vec> Mat;
171 int n,m;
172 struct node{int l,r,sum;}T[N*40];
173 int a[N],root[N],pre[N],tot;
174 int q,x,y;
175 int ans[N];
176 vector<int> v;
177 int getid(int x){ return lower_bound(v.begin(),v.end(),x) - v.begin() + 1;}
178 void init(){
179     tot = 0;
180     memset(root,0,sizeof(root));
181     memset(pre,-1,sizeof(pre));
182     v.clear();
183 }
184 void update(int l,int r,int val,int &x,int y,int pos){
185     T[++tot] = T[y] , T[tot].sum += val , x = tot;
186     if(l == r) return ;
187     int mid = (l + r) >> 1;
188     if(pos <= mid) update(l,mid,val,T[x].l,T[y].l,pos);
189     else update(mid+1,r,val,T[x].r,T[y].r,pos);
190 }
191 **
192  *        【x=L,y=R】 不同数字的有多少个
193  *        query(1,n,x,y,root[y]);  第y颗树。
194  *
195 int query(int l,int r,int L,int R,int x){
196     if(L <= l && r <= R) return T[x].sum;
197     int mid = (l+r) >> 1 , ret = 0;
198     if(L <= mid) ret += query(l,mid,L,R,T[x].l);
199     if(R > mid) ret += query(mid+1,r,L,R,T[x].r);
200     return ret;
201 }
202 int main(){
203     int kase = 1,T;
204     cin >> T;
205     while(T --){
206         cin >> n >> m;
207         init();
208         for(int i = 1 ; i <= n ; i ++) scanf("%d",&a[i]) , v.push_back(a[i]);
209         sort(v.begin(),v.end());
210         v.erase(unique(v.begin(),v.end()),v.end());
211         for(int i = 1 ; i <= n ; i ++){
212             int id = getid(a[i]);
213             if(pre[id] == -1){
214                 update(1,n,1,root[i],root[i-1],i);
215                 pre[id] = i;
216             }else{
217                 int tmp;
218                 update(1,n,-1,tmp,root[i-1],pre[id]);
219                 update(1,n,1,root[i],tmp,i);
220                 pre[id] = i;
221             }
222         }
223         ans[0] = 0;
224         printf("Case #%d:",kase ++);
225         for(int i = 1 ; i <= m ; i ++){
226             scanf("%d%d",&x,&y);
227             int l,r;
228             l = min((x+ans[i-1])%n+1,(y+ans[i-1])%n+1);
229             r = max((x+ans[i-1])%n+1,(y+ans[i-1])%n+1);
230             //l = x ; r = y;
231             //printf("%d %d !!
",l,r);
232             int k = (query(1,n,l,r,root[r])+1) / 2;
233             int ll = l , rr = r;
234             while(ll < rr){
235                 int mid = (ll + rr) / 2;
236                 int t = query(1,n,l,mid,root[mid]);
237                 if(t < k) ll = mid+1;
238                 else rr = mid;
239             }
240             printf(" %d",rr);
241             ans[i] = rr;
242         }
243         puts("");
244     }
245     return 0;
246 }
247 */
长春 I 题

  有几点想说明的:1.下面注释的是大力的代码,但是超时了,因为他的query方法和我的有点小差别,虽然都能实现需要的功能,但是似乎我的query方法复杂度更小一点(??)。。不过我的也是卡过的,但是我觉得在长春现场赛的话应该能过,感觉HDU的评测机这次有点坑。。2.我的代码本来是WA的,因为数组开小了,我上面的两个代码都是*20的,都没问题,这里必须要开*40的才行,被坑了这一次以后我下次都开大一点的好了,反正*40内存也够用= =。。那么主席树就写到这里好了,以后刷了题目有什么要补充的再补充好了~(话说我的数据结构真的好烂啊,,以后搞splay怎么办啊233。。)

原文地址:https://www.cnblogs.com/zzyDS/p/5931453.html