【模板】【伸展树】Splay Tree小结

  先发两个模板,都是网上找的略作修改后的。

  1.普通版。支持lazy操作、重复值、找前驱后继、单个删除、值区间删除等。代码对应的题是BZOJ1588 - 营业额统计。

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 const int MAXN = 100011;
  4 
  5 struct SplayTree{
  6     int cnt, rt;
  7     int Add[MAXN];
  8 
  9     struct Node{
 10         int key, cnt, size, fa, son[2];
 11     }T[MAXN];
 12 
 13     inline void PushUp(int x){
 14         T[x].size=T[T[x].son[0]].size+T[T[x].son[1]].size+T[x].cnt;
 15     }
 16 
 17     inline void PushDown(int x){
 18         if(Add[x]){
 19             if(T[x].son[0]){
 20                 T[T[x].son[0]].key+=Add[x];
 21                 Add[T[x].son[0]]+=Add[x];
 22             }
 23             if(T[x].son[1]){
 24                 T[T[x].son[1]].key+=Add[x];
 25                 Add[T[x].son[1]]+=Add[x];
 26             }
 27             Add[x]=0;
 28         }
 29     }
 30 
 31     inline int Newnode(int key, int fa){ //新建一个节点并返回
 32         ++cnt;
 33         T[cnt].key=key;
 34         T[cnt].cnt=T[cnt].size=1;
 35         T[cnt].fa=fa;
 36         T[cnt].son[0]=T[cnt].son[1]=0;
 37         return cnt;
 38     }
 39 
 40     inline void Rotate(int x, int p){ //0左旋 1右旋
 41         int y=T[x].fa;
 42         PushDown(y);
 43         PushDown(x);
 44         T[y].son[!p]=T[x].son[p];
 45         T[T[x].son[p]].fa=y;
 46         T[x].fa=T[y].fa;
 47         if(T[x].fa)
 48             T[T[x].fa].son[T[T[x].fa].son[1] == y]=x;
 49         T[x].son[p]=y;
 50         T[y].fa=x;
 51         PushUp(y);
 52         PushUp(x);
 53     }
 54 
 55     void Splay(int x, int to){ //将x节点移动到To的子节点中
 56         while(T[x].fa != to){
 57             if(T[T[x].fa].fa == to)
 58                 Rotate(x, T[T[x].fa].son[0] == x);
 59             else{
 60                 int y=T[x].fa, z=T[y].fa;
 61                 int p=(T[z].son[0] == y);
 62                 if(T[y].son[p] == x)
 63                     Rotate(x, !p), Rotate(x, p); //之字旋
 64                 else
 65                     Rotate(y, p), Rotate(x, p); //一字旋
 66             }
 67         }
 68         if(to == 0) rt=x;
 69     }
 70 
 71     int GetKth(int k){
 72         if(!rt || k > T[rt].size) return -1e9;  // 若要节点id,改为0
 73         int x=rt;
 74         while(x){
 75             PushDown(x);
 76             if(k >= T[T[x].son[0]].size+1 && k <= T[T[x].son[0]].size+T[x].cnt)
 77                 break;
 78             if(k > T[T[x].son[0]].size+T[x].cnt){
 79                 k-=T[T[x].son[0]].size+T[x].cnt;
 80                 x=T[x].son[1];
 81             }
 82             else
 83                 x=T[x].son[0];
 84         }
 85         return T[x].key;   // 若要节点id,改为x
 86     }
 87 
 88     int Find(int key){ //返回值为key的节点 若无返回0 若有将其转移到根处
 89         if(!rt) return 0;
 90         int x=rt;
 91         while(x){
 92             PushDown(x);
 93             if(T[x].key == key) break;
 94             x=T[x].son[key > T[x].key];
 95         }
 96         if(x) Splay(x, 0);
 97         return x;
 98     }
 99 
100     int Prev(){ //返回根节点的前驱 非重点
101         if(!rt || !T[rt].son[0]) return 0;
102         int x=T[rt].son[0];
103         while(T[x].son[1]){
104             PushDown(x);
105             x=T[x].son[1];
106         }
107         Splay(x, 0);
108         return x;
109     }
110 
111     int Succ(){ //返回根结点的后继 非重点
112         if(!rt || !T[rt].son[1]) return 0;
113         int x=T[rt].son[1];
114         while(T[x].son[0]){
115             PushDown(x);
116             x=T[x].son[0];
117         }
118         Splay(x, 0);
119         return x;
120     }
121 
122     void Insert(int key){ //插入key值
123         if(!rt)
124             rt=Newnode(key, 0);
125         else{
126             int x=rt, y=0;
127             while(x){
128                 PushDown(x);
129                 y=x;
130                 if(T[x].key == key){
131                     T[x].cnt++;
132                     T[x].size++;
133                     break;
134                 }
135                 T[x].size++;
136                 x=T[x].son[key > T[x].key];
137             }
138             if(!x)
139                 x=T[y].son[key > T[y].key]=Newnode(key, y);
140             Splay(x, 0);
141         }
142     }
143 
144     void Delete(int key){ //删除值为key的节点1个
145         int x=Find(key);
146         if(!x) return;
147         if(T[x].cnt>1){
148             T[x].cnt--;
149             PushUp(x);
150             return;
151         }
152         int y=T[x].son[0];
153         while(T[y].son[1])
154             y=T[y].son[1];
155         int z=T[x].son[1];
156         while(T[z].son[0])
157             z=T[z].son[0];
158         if(!y && !z){
159             rt=0;
160             return;
161         }
162         if(!y){
163             Splay(z, 0);
164             T[z].son[0]=0;
165             PushUp(z);
166             return;
167         }
168         if(!z){
169             Splay(y, 0);
170             T[y].son[1]=0;
171             PushUp(y);
172             return;
173         }
174         Splay(y, 0);
175         Splay(z, y);
176         T[z].son[0]=0;
177         PushUp(z);
178         PushUp(y);
179     }
180 
181     int GetRank(int key){ //获得值<=key的节点个数
182         if(!Find(key)){
183             Insert(key);
184             int tmp=T[T[rt].son[0]].size;
185             Delete(key);
186             return tmp;
187         }
188         else
189             return T[T[rt].son[0]].size+T[rt].cnt;
190     }
191 
192     void Delete(int l, int r){ //删除值在[l, r]中的所有节点 l!=r
193         if(!Find(l)) Insert(l);
194         int p=Prev();
195         if(!Find(r)) Insert(r);
196         int q=Succ();
197         if(!p && !q){
198             rt=0;
199             return;
200         }
201         if(!p){
202             T[rt].son[0]=0;
203             PushUp(rt);
204             return;
205         }
206         if(!q){
207             Splay(p, 0);
208             T[rt].son[1]=0;
209             PushUp(rt);
210             return;
211         }
212         Splay(p, q);
213         T[p].son[1]=0;
214         PushUp(p);
215         PushUp(q);
216     }
217 
218     int solve(int key){
219         if(!rt) return 0;
220         int x = rt, res = 2e9;
221         while(x){
222             res = min(res,abs(T[x].key-key));
223             x = T[x].son[key > T[x].key];
224         }
225         return res;
226     }
227 }spt;
228 
229 int n,x,ans;
230 
231 int main(){
232     cin >> n >> x;
233     spt.Insert(x);
234     ans = x;
235     for(int i = 1;i < n;++i){
236         scanf("%d",&x);
237         ans += spt.solve(x);
238         spt.Insert(x);
239     }
240     cout << ans << endl;
241 
242     return 0;
243 }
View Code

  2.区间操作版。支持区间更新、区间查询、区间翻转,也可以很轻松地再添加区间切割、插入等。代码对应的题是BZOJ1251 - 序列终结者。

  1 /*bzoj 1251 序列终结者
  2   题意:
  3   给定一个长度为N的序列,每个序列的元素是一个整数。要支持以下三种操作:
  4   1. 将[L,R]这个区间内的所有数加上V;
  5   2. 将[L,R]这个区间翻转,比如1 2 3 4变成4 3 2 1;
  6   3. 求[L,R]这个区间中的最大值;
  7   最开始所有元素都是0。
  8   限制:
  9   N <= 50000, M <= 100000
 10   思路:
 11   伸展树
 12 
 13   关键点:
 14   1. 伸展树为左小右大的二叉树,所以旋转操作不会影响树的性质
 15   2. 区间操作为:
 16         int u = select(L - 1), v = select(R + 1);
 17         splay(u, 0); splay(v, u);    //通过旋转操作把询问的区间聚集到根的右子树的左子树下
 18      因为伸展树为左小右大的二叉树,旋转操作后的所以对于闭区间[L, R]之间的所有元素都聚集在根的右子树的左子树下
 19      因为闭区间[L, R],
 20      1) 所以每次都要查开区间(L - 1, R + 1),
 21      2) 所以伸展树元素1对应的标号为2,
 22      3) 所以node[0]对应空节点,node[1]对应比所以元素标号都小的点,node[2 ~ n + 1]对应元素1 ~ n,node[n + 2]对应比所有元素标号都打的点,其中node[0], node[1], node[n + 2]都是虚节点,不代表任何元素。
 23  */
 24 #include<bits/stdc++.h>
 25 using namespace std;
 26 
 27 #define LS(n) node[(n)].ch[0]
 28 #define RS(n) node[(n)].ch[1]
 29 
 30 const int N = 1e5 + 5;
 31 const int INF = 0x3f3f3f3f;
 32 struct Splay {
 33     struct Node{
 34         int fa, ch[2];
 35         bool rev;
 36         int val, lazy, mx, size;
 37         void init(int _val) {
 38             val = mx = _val;
 39             size = 1;
 40             lazy = rev = ch[0] = ch[1] = 0;
 41         }
 42     } node[N];
 43     int root;
 44 
 45     void pushup(int n) {
 46         node[n].mx = max(node[n].val, max(node[LS(n)].mx, node[RS(n)].mx));
 47         node[n].size = node[LS(n)].size + node[RS(n)].size + 1;
 48     }
 49 
 50     void pushdown(int n) {
 51         if(n == 0) return ;
 52         if(node[n].lazy) {
 53             if(LS(n)) {
 54                 node[LS(n)].val += node[n].lazy;
 55                 node[LS(n)].mx += node[n].lazy;
 56                 node[LS(n)].lazy += node[n].lazy;
 57             }
 58             if(RS(n)) {
 59                 node[RS(n)].val += node[n].lazy;
 60                 node[RS(n)].mx += node[n].lazy;
 61                 node[RS(n)].lazy += node[n].lazy;
 62             }
 63             node[n].lazy = 0;
 64         }
 65         if(node[n].rev) {
 66             if(LS(n)) node[LS(n)].rev ^= 1;
 67             if(RS(n)) node[RS(n)].rev ^= 1;
 68             swap(LS(n), RS(n));
 69             node[n].rev = 0;
 70         }
 71     }
 72 
 73     void rotate(int n, bool kind) {
 74         int fn = node[n].fa;
 75         int ffn = node[fn].fa;
 76         node[fn].ch[!kind] = node[n].ch[kind];
 77         node[node[n].ch[kind]].fa = fn;
 78 
 79         node[n].ch[kind] = fn;
 80         node[fn].fa = n;
 81 
 82         node[ffn].ch[RS(ffn) == fn] = n;
 83         node[n].fa = ffn;
 84         pushup(fn);
 85     }
 86 
 87     //旋转到goal的儿子处
 88     void splay(int n, int goal) {
 89         pushdown(n);
 90         while(node[n].fa != goal) {
 91             int fn = node[n].fa;
 92             int ffn = node[fn].fa;
 93             pushdown(ffn); pushdown(fn); pushdown(n);
 94             bool rotate_n = (LS(fn) == n);
 95             bool rotate_fn = (LS(ffn) == fn);
 96             if(ffn == goal) rotate(n, rotate_n);
 97             else {
 98                 if(rotate_n == rotate_fn) rotate(fn, rotate_fn);
 99                 else rotate(n, rotate_n);
100                 rotate(n, rotate_fn);
101             }
102         }
103         pushup(n);
104         if(goal == 0) root = n;
105     }
106 
107     int select(int pos) {
108         int u = root;
109         pushdown(u);
110         while(node[LS(u)].size != pos) {
111             if(pos < node[LS(u)].size)
112                 u = LS(u);
113             else {
114                 pos -= node[LS(u)].size + 1;
115                 u = RS(u);
116             }
117             pushdown(u);
118         }
119         return u;
120     }
121 
122     int build(int L, int R) {
123         if(L > R) return 0;
124         if(L == R) return L;
125         int mid = (L + R) >> 1;
126         int r_L, r_R;
127         LS(mid) = r_L = build(L, mid - 1);
128         RS(mid) = r_R = build(mid + 1, R);
129         node[r_L].fa = node[r_R].fa = mid;
130         pushup(mid);
131         return mid;
132     }
133 
134     void init(int n) {
135         node[0].init(-INF); node[0].size = 0;
136         node[1].init(-INF);
137         node[n + 2].init(-INF);
138         for(int i = 2; i <= n + 1; ++i)
139             node[i].init(0);
140 
141         root = build(1, n + 2);
142         node[root].fa = 0;
143 
144         node[0].fa = 0;
145         LS(0) = root;
146     }
147 
148     void solve(int type,int l,int r,int val){
149         int u = select(l-1), v = select(r+1);
150         splay(u,0);splay(v,u);
151         if(type == 1){      // Update
152             node[LS(v)].val += val;
153             node[LS(v)].mx += val;
154             node[LS(v)].lazy += val;
155         }
156         else if(type == 2)  // Reverse
157             node[LS(v)].rev ^= 1;
158         else                // Query
159             printf("%d
",node[LS(v)].mx);
160     }
161 } spt;
162 
163 int main() {
164     int n, m;
165     scanf("%d%d", &n, &m);
166     spt.init(n);
167     for(int i = 0; i < m; ++i) {
168         int op,l,r,v;
169         scanf("%d%d%d",&op,&l,&r);
170         if(op == 1) scanf("%d",&v);
171         spt.solve(op,l,r,v);
172     }
173     return 0;
174 }
View Code

BZOJ1503 - 郁闷的出纳员(插入、删除、查第k大)

思路:模板题,用第一个模板就好了。关键代码如下。

    void add(int val){
        int u = Find(-1e9), v = Find(1e9);
        Splay(u,0); Splay(v,u);
        Add[T[v].son[0]] += val;
        T[T[v].son[0]].key += val;
    }
}spt;

int n,m,cnt;

int main(){
    cin >> n >> m;
    spt.Insert(-1e9);
    spt.Insert(1e9);
    while(n--){
        char op[3];
        int k;
        scanf("%s%d",op,&k);
        if(op[0] == 'I'){
            if(k >= m) spt.Insert(k);
        }
        if(op[0] == 'A') spt.add(k);
        if(op[0] == 'S'){
            spt.add(-k);
            cnt += spt.GetRank(m-1) - 1;
            spt.Delete(-5e8,m-1);
        }
        if(op[0] == 'F'){
            int t = spt.GetKth(k+1);
            if(t < -5e8) printf("-1
");
            else printf("%d
",t);
        }
    }
    printf("%d
",cnt);

    return 0;
}
View Code

HDU3487 - Play with chain(区间翻转、切割)

题意:开始有一个1, 2, 3,... , n的序列,进行m次操作,CUT a b c将区间[a,b]取出得到新序列,将区间插入到新序列第c个元素之后,FLIP a b 将区间[a,b]翻转。输出最终的序列。

思路:对于CUT操作我们需要先提取出区间[a,b],然后删除,然后以c为边界分割为左右两部分,再合并c左边的区间和提取出来的区间[a,b],随后将最右的旋转至根(注意在第二个模板中,size是把初始化时的左右2个虚节点算进去了的),然后和c右边的区间合并。对于FLIP操作,我们可以进行lazy操作,先打个标记但不翻转,需要的时候再翻转。关键代码如下。

    void solve(int type,int l,int r,int c){
        int u = select(l-1), v = select(r+1);
        splay(u,0);splay(v,u);
        if(type == 1)       /**< Reverse */
            node[LS(v)].rev ^= 1;
        else{               /**< Cut */
            int rt1 = LS(v);
            LS(v) = 0;      // Delete [l,r]
            pushup(v);
            pushup(u);

            u = select(c);
            splay(u,0);
            int rt2 = RS(u);

            RS(u) = rt1;    // Merge [1,c] with [l,r]
            node[rt1].fa = u;
            pushup(u);

            u = select(node[root].size-1);
            splay(u,0);

            RS(u) = rt2;   // Merge
            node[rt2].fa = u;
            pushup(u);
        }
    }

    void traverse(int x){
        if(!x) return;
        pushdown(x);
        traverse(LS(x));
        ans.push_back(node[x].val);
        traverse(RS(x));
    }
View Code
原文地址:https://www.cnblogs.com/doub7e/p/7506107.html