splay 伸展树 代码实现

Splay 概念文章: http://blog.csdn.net/naivebaby/article/details/1357734

叉姐 数组实现: https://github.com/ftiasch/mithril/blob/master/2012-10-24/I.cpp#L43

Vani 指针实现: https://github.com/Azure-Vani/acm-icpc/blob/master/spoj/SEQ2.cpp

hdu 1890 写法: http://blog.csdn.net/fp_hzq/article/details/8087431

HH splay写法: http://www.notonlysuccess.com/index.php/splay-tree/

POJ 3468(区间加、区间求和)HH 写法

View Code
  1 /*
  2 http://acm.pku.edu.cn/JudgeOnline/problem?id=3468
  3 区间跟新,区间求和
  4 */
  5 #include <cstdio>
  6 #define keyTree (ch[ ch[root][1] ][0])
  7 const int maxn = 222222;
  8 struct SplayTree{
  9     int sz[maxn];
 10     int ch[maxn][2];
 11     int pre[maxn];
 12     int root , top1 , top2;
 13     int ss[maxn] , que[maxn];
 14  
 15     inline void Rotate(int x,int f) {
 16         int y = pre[x];
 17         push_down(y);
 18         push_down(x);
 19         ch[y][!f] = ch[x][f];
 20         pre[ ch[x][f] ] = y;
 21         pre[x] = pre[y];
 22         if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] = x;
 23         ch[x][f] = y;
 24         pre[y] = x;
 25         push_up(y);
 26     }
 27     inline void Splay(int x,int goal) {
 28         push_down(x);
 29         while(pre[x] != goal) {
 30             if(pre[pre[x]] == goal) {
 31                 Rotate(x , ch[pre[x]][0] == x);
 32             } else {
 33                 int y = pre[x] , z = pre[y];
 34                 int f = (ch[z][0] == y);
 35                 if(ch[y][f] == x) {
 36                     Rotate(x , !f) , Rotate(x , f);
 37                 } else {
 38                     Rotate(y , f) , Rotate(x , f);
 39                 }
 40             }
 41         }
 42         push_up(x);
 43         if(goal == 0) root = x;
 44     }
 45     inline void RotateTo(int k,int goal) {//把第k位的数转到goal下边
 46         int x = root;
 47         push_down(x);
 48         while(sz[ ch[x][0] ] != k) {
 49             if(k < sz[ ch[x][0] ]) {
 50                 x = ch[x][0];
 51             } else {
 52                 k -= (sz[ ch[x][0] ] + 1);
 53                 x = ch[x][1];
 54             }
 55             push_down(x);
 56         }
 57         Splay(x,goal);
 58     }
 59     inline void erase(int x) {//把以x为祖先结点删掉放进内存池,回收内存
 60         int father = pre[x];
 61         int head = 0 , tail = 0;
 62         for (que[tail++] = x ; head < tail ; head ++) {
 63             ss[top2 ++] = que[head];
 64             if(ch[ que[head] ][0]) que[tail++] = ch[ que[head] ][0];
 65             if(ch[ que[head] ][1]) que[tail++] = ch[ que[head] ][1];
 66         }
 67         ch[ father ][ ch[father][1] == x ] = 0;
 68         pushup(father);
 69     }
 70     //以上一般不修改//////////////////////////////////////////////////////////////////////////////
 71     void debug() {printf("%d\n",root);Treaval(root);}
 72     void Treaval(int x) {
 73         if(x) {
 74             Treaval(ch[x][0]);
 75             printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,val = %2d\n",x,ch[x][0],ch[x][1],pre[x],sz[x],val[x]);
 76             Treaval(ch[x][1]);
 77         }
 78     }
 79     //以上Debug
 80  
 81  
 82     //以下是题目的特定函数:
 83     inline void NewNode(int &x,int c) {
 84         if (top2) x = ss[--top2];//用栈手动压的内存池
 85         else x = ++top1;
 86         ch[x][0] = ch[x][1] = pre[x] = 0;
 87         sz[x] = 1;
 88  
 89         val[x] = sum[x] = c;/*这是题目特定函数*/
 90         add[x] = 0;
 91     }
 92  
 93     //把延迟标记推到孩子
 94     inline void push_down(int x) {/*这是题目特定函数*/
 95         if(add[x]) {
 96             val[x] += add[x];
 97             add[ ch[x][0] ] += add[x];
 98             add[ ch[x][1] ] += add[x];
 99             sum[ ch[x][0] ] += (long long)sz[ ch[x][0] ] * add[x];
100             sum[ ch[x][1] ] += (long long)sz[ ch[x][1] ] * add[x];
101             add[x] = 0;
102         }
103     }
104     //把孩子状态更新上来
105     inline void push_up(int x) {
106         sz[x] = 1 + sz[ ch[x][0] ] + sz[ ch[x][1] ];
107         /*这是题目特定函数*/
108         sum[x] = add[x] + val[x] + sum[ ch[x][0] ] + sum[ ch[x][1] ];
109     }
110  
111     /*初始化*/
112     inline void makeTree(int &x,int l,int r,int f) {
113         if(l > r) return ;
114         int m = (l + r)>>1;
115         NewNode(x , num[m]);        /*num[m]权值改成题目所需的*/
116         makeTree(ch[x][0] , l , m - 1 , x);
117         makeTree(ch[x][1] , m + 1 , r , x);
118         pre[x] = f;
119         push_up(x);
120     }
121     inline void init(int n) {/*这是题目特定函数*/
122         ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
123         add[0] = sum[0] = 0;
124  
125         root = top1 = 0;
126         //为了方便处理边界,加两个边界顶点
127         NewNode(root , -1);
128         NewNode(ch[root][1] , -1);
129         pre[top1] = root;
130         sz[root] = 2;
131  
132  
133         for (int i = 0 ; i < n ; i ++) scanf("%d",&num[i]);
134         makeTree(keyTree , 0 , n-1 , ch[root][1]);
135         push_up(ch[root][1]);
136         push_up(root);
137     }
138     /*更新*/
139     inline void update( ) {/*这是题目特定函数*/
140         int l , r , c;
141         scanf("%d%d%d",&l,&r,&c);
142         RotateTo(l-1,0);
143         RotateTo(r+1,root);
144         add[ keyTree ] += c;
145         sum[ keyTree ] += (long long)c * sz[ keyTree ];
146     }
147     /*询问*/
148     inline void query() {/*这是题目特定函数*/
149         int l , r;
150         scanf("%d%d",&l,&r);
151         RotateTo(l-1 , 0);
152         RotateTo(r+1 , root);
153         printf("%lld\n",sum[keyTree]);
154     }
155  
156  
157     /*这是题目特定变量*/
158     int num[maxn];
159     int val[maxn];
160     int add[maxn];
161     long long sum[maxn];
162 }spt;
163  
164  
165 int main() {
166     int n , m;
167     scanf("%d%d",&n,&m);
168     spt.init(n);
169     while(m --) {
170         char op[2];
171         scanf("%s",op);
172         if(op[0] == 'Q') {
173             spt.query();
174         } else {
175             spt.update();
176         }
177     }
178     return 0;
179 }

叉姐 数组实现(mithril 2012-10-24 I 题)

View Code
  1 #include <cstdio>
  2 #include <cstring>
  3 #include <vector>
  4 #include <climits>
  5 #include <algorithm>
  6 using namespace std;
  7 
  8 const int N = 200000;
  9 const int M = 1 + (N << 1);
 10 const int EMPTY = M - 1;
 11 
 12 const int MOD = 99990001;
 13 
 14 int nodeCount, type[M], parent[M], children[M][2], id[M];
 15 
 16 int scale[M], delta[M], weight[M], size[M], minimum[M];
 17 
 18 void update(int x) {
 19     size[x] = size[children[x][0]] + 1 + size[children[x][1]];
 20     minimum[x] = min(min(minimum[children[x][0]], minimum[children[x][1]]), id[x]);
 21 }
 22 
 23 void modify(int x, int k, int b) {
 24     weight[x] = ((long long)k * weight[x] + b) % MOD;
 25     scale[x] = (long long)k * scale[x] % MOD;
 26     delta[x] = ((long long)k * delta[x] + b) % MOD;
 27 }
 28 
 29 void pushDown(int x) {
 30     for (int i = 0; i < 2; ++ i) {
 31         if (children[x][i] != EMPTY) {
 32             modify(children[x][i], scale[x], delta[x]);
 33         }
 34     }
 35     scale[x] = 1;
 36     delta[x] = 0;
 37 }
 38 
 39 void rotate(int x) {
 40     int t = type[x];
 41     int y = parent[x];
 42     int z = children[x][1 ^ t];
 43     type[x] = type[y];
 44     parent[x] = parent[y];
 45     if (type[x] != 2) {
 46         children[parent[x]][type[x]] = x;
 47     }
 48     type[y] = 1 ^ t;
 49     parent[y] = x;
 50     children[x][1 ^ t] = y;
 51     if (z != EMPTY) {
 52         type[z] = t;
 53         parent[z] = y;
 54     }
 55     children[y][t] = z;
 56     update(y);
 57 }
 58 
 59 void splay(int x) {
 60     if (x == EMPTY) {
 61         return;
 62     }
 63     vector <int> stack(1, x);
 64     for (int i = x; type[i] != 2; i = parent[i]) {
 65         stack.push_back(parent[i]);
 66     }
 67     while (!stack.empty()) {
 68         pushDown(stack.back());
 69         stack.pop_back();
 70     }
 71     while (type[x] != 2) {
 72         int y = parent[x];
 73         if (type[x] == type[y]) {
 74             rotate(y);
 75         } else {
 76             rotate(x);
 77         }
 78         if (type[x] == 2) {
 79             break;
 80         }
 81         rotate(x);
 82     }
 83     update(x);
 84 }
 85 
 86 int goLeft(int x) {
 87     while (children[x][0] != EMPTY) {
 88         x = children[x][0];
 89     }
 90     return x;
 91 }
 92 
 93 int join(int x, int y) {
 94     if (x == EMPTY || y == EMPTY) {
 95         return x != EMPTY ? x : y;
 96     }
 97     y = goLeft(y);
 98     splay(y);
 99     splay(x);
100     type[x] = 0;
101     parent[x] = y;
102     children[y][0] = x;
103     update(y);
104     return y;
105 }
106 
107 pair <int, int> split(int x) {
108     splay(x);
109     int a = children[x][0];
110     int b = children[x][1];
111     children[x][0] = children[x][1] = EMPTY;
112     if (a != EMPTY) {
113         type[a] = 2;
114         parent[a] = EMPTY;
115     }
116     if (b != EMPTY) {
117         type[b] = 2;
118         parent[b] = EMPTY;
119     }
120     return make_pair(a, b);
121 }
122 
123 int newNode(int init, int vid) {
124     int x = nodeCount ++;
125     type[x] = 2;
126     parent[x] = children[x][0] = children[x][1] = EMPTY;
127     id[x] = vid;
128     weight[x] = init;
129     scale[x] = 1;
130     delta[x] = 0;
131     update(x);
132     return x;
133 }
134 
135 int n;
136 int edgeCount, firstEdge[N], to[M], nextEdge[M], initWeight[N], position[M];
137 
138 int root;
139 
140 void addEdge(int u, int v) {
141     to[edgeCount] = v;
142     nextEdge[edgeCount] = firstEdge[u];
143     firstEdge[u] = edgeCount ++;
144 }
145 
146 void dfs(int p, int u) {
147     for (int iter = firstEdge[u]; iter != -1; iter = nextEdge[iter]) {
148         int v = to[iter];
149         if (v != p) {
150             position[iter] = nodeCount;
151             root = join(root, newNode(initWeight[iter >> 1], min(u, v)));
152             dfs(u, v);
153             position[iter ^ 1] = nodeCount;
154             root = join(root, newNode(initWeight[iter >> 1], min(u, v)));
155         }
156     }
157 }
158 
159 int getRank(int x) { // 1-based
160     splay(x);
161     return size[children[x][0]] + 1;
162 }
163 
164 void print(int root) {
165     if (root != EMPTY) {
166         printf("[ ");
167         print(children[root][0]);
168         printf(" %d ", root);
169         print(children[root][1]);
170         printf(" ]");
171     }
172 }
173 
174 int main() {
175     size[EMPTY] = 0;
176     minimum[EMPTY] = INT_MAX;
177     parent[EMPTY] = 2;
178     scanf("%d", &n);
179     edgeCount = 0;
180     memset(firstEdge, -1, sizeof(firstEdge));
181     for (int i = 0; i < n - 1; ++ i) {
182         int a, b;
183         scanf("%d%d%d", &a, &b, initWeight + i);
184         a --;
185         b --;
186         addEdge(a, b);
187         addEdge(b, a);
188     }
189     nodeCount = 0;
190     root = EMPTY;
191     dfs(-1, 0);
192     for (int i = 0; i < n - 1; ++ i) {
193         int id;
194         scanf("%d", &id);
195         id --;
196 
197         int a = position[id << 1];
198         int b = position[(id << 1) ^ 1];
199         if (getRank(a) > getRank(b)) {
200             swap(a, b);
201         }
202         splay(a);
203 
204         int output = weight[a];
205         printf("%d\n", output);
206         fflush(stdout);
207 
208         pair <int, int> ret1 = split(a);
209         pair <int, int> ret2 = split(b);
210         int x = ret1.first;
211         int y = ret2.first;
212         int z = ret2.second;
213         x = join(z, x);
214         splay(x);
215         splay(y);
216         if (size[x] > size[y]) {
217             swap(x, y);
218         }
219         if (size[x] == size[y] && minimum[x] > minimum[y]) {
220             swap(x, y);
221         }
222         modify(x, output, 0);
223         modify(y, 1, output);
224     }
225     return 0;
226 }

SPOJ SEQ2(数列维护:插入、删除、区间赋值、区间翻转、区间求和、最大子段和)

Vani 指针实现

View Code
  1 #include <cstdio>
  2 #include <cctype>
  3 #include <algorithm>
  4 #include <cstring>
  5 
  6 using namespace std;
  7 
  8 namespace Solve {
  9     const int MAXN = 500010;
 10     const int inf = 500000000;
 11 
 12     char BUF[50000000], *pos = BUF;
 13     inline int ScanInt(void) {
 14         int r = 0, d = 0;
 15         while (!isdigit(*pos) && *pos != '-') pos++;
 16         if (*pos != '-') r = *pos - 48; else d = 1; pos++;
 17         while ( isdigit(*pos)) r = r * 10 + *pos++ - 48;
 18         return d ? -r : r;
 19     }
 20     inline void ScanStr(char *st) {
 21         int l = 0;
 22         while (!(isupper(*pos) || *pos == '-')) pos++;
 23         st[l++] = *pos++;
 24         while (isupper(*pos) || *pos == '-') st[l++] = *pos++; st[l] = 0;
 25     }
 26 
 27     struct Node {
 28         Node *ch[2], *p;
 29         int v, lmax, rmax, m, same, rev, sum, size;
 30         inline bool dir(void) {return this == p->ch[1];}
 31         inline void SetC(Node *x, bool d) {ch[d] = x, x->p = this;}
 32         inline void Update(void) {
 33             Node *L = ch[0], *R = ch[1];
 34             size = L->size + R->size + 1;
 35             m = max(L->m, R->m);
 36             m = max(m, L->rmax + v + R->lmax);
 37             lmax = max(L->lmax, L->sum + v + R->lmax);
 38             rmax = max(R->rmax, R->sum + v + L->rmax);
 39             sum = L->sum + R->sum + v;
 40         }
 41         inline void Rev(void) {
 42             if (v == -inf) return;
 43             rev ^= 1;
 44             swap(ch[0], ch[1]);
 45             swap(lmax, rmax);
 46         }
 47         inline void Same(int u) {
 48             if (v == -inf) return;
 49             same = u;
 50             sum = u * size;
 51             if (sum > 0) lmax = rmax = m = sum; else lmax = 0, rmax = 0, m = u;
 52             v = u;
 53         }
 54         inline void Down(void) {
 55             if (rev) {
 56                 ch[0]->Rev(), ch[1]->Rev();
 57                 rev = 0;
 58             }
 59             if (same != -inf) {
 60                 ch[0]->Same(same), ch[1]->Same(same);
 61                 same = -inf;
 62             }
 63         }
 64     } Tnull, *null = &Tnull;
 65 
 66     class Splay {public:
 67         Node *root;
 68         inline void rotate(Node *x) {
 69             Node *p = x->p; bool d = x->dir();
 70             p->Down(); x->Down();
 71             p->p->SetC(x, p->dir());
 72             p->SetC(x->ch[!d], d);
 73             x->SetC(p, !d);
 74             p->Update();
 75         }
 76         inline void splay(Node *x, Node *G) {
 77             if (G == null) root = x;
 78             while (x->p != G) {
 79                 if (x->p->p == G) {rotate(x); break;}
 80                 else {if (x->dir() == x->p->dir()) rotate(x->p), rotate(x); else rotate(x), rotate(x);}
 81             }
 82             x->Update();
 83         }
 84         inline Node *Select(int k) {
 85             Node *t = root;
 86             while (t->Down(), t->ch[0]->size + 1 != k) {
 87                 if (k > t->ch[0]->size + 1) k -= t->ch[0]->size + 1, t = t->ch[1];
 88                 else t = t->ch[0];
 89             }
 90             splay(t, null);
 91             return t;
 92         }
 93         inline Node *getInterval(int l, int r) {
 94             Node *L = Select(l), *R = Select(r + 2);
 95             splay(L, null); splay(R, L);
 96             L->Down(); R->Down();
 97             return R;
 98         }
 99         inline void Insert(int pos, Node *x) {
100             Node *now = getInterval(pos + 1, pos);
101             now->SetC(x, 0);
102             now->Update(); root->Update();
103         }
104         inline void Delete(int l, int r) {
105             Node *now = getInterval(l, r);
106             now->ch[0] = null;
107             now->Update(); root->Update();
108         }
109         inline void Make(int l, int r, int c) {
110             Node *now = getInterval(l, r);
111             now->ch[0]->Same(c);
112             now->Update(); root->Update();
113         }
114         inline void Reverse(int l, int r) {
115             Node *now = getInterval(l, r);
116             now->ch[0]->Rev();
117             now->Update(); root->Update();
118         }
119         inline int Sum(int l, int r) {
120             Node *now = getInterval(l, r);
121             root->Down(); now->Down();
122             return now->ch[0]->sum;
123         }
124         inline int maxSum(int l, int r) {
125             Node *now = getInterval(l, r);
126             root->Down(); now->Down();
127             return now->ch[0]->m;
128         }
129         inline Node* Renew(int c) {
130             Node *ret = new Node;
131             ret->ch[0] = ret->ch[1] = ret->p = null; ret->size = 1;
132             ret->Same(c); ret->same = -inf;
133             return ret;
134         }
135         inline Node* Build(int l, int r, int *a) {
136             if (l > r) return null;
137             int mid = (l + r) >> 1;
138             Node *ret = Renew(a[mid]);
139             ret->ch[0] = Build(l, mid - 1, a);
140             ret->ch[1] = Build(mid + 1, r, a);
141             ret->ch[0]->p = ret->ch[1]->p = ret;
142             ret->Update();
143             return ret;
144         }
145         inline void P(Node *t) {
146             if (t == null) return;
147             t->Down(); t->Update();
148             P(t->ch[0]);
149             printf("%d ", t->v);
150             P(t->ch[1]);
151         }
152     }T;
153 
154 
155     int a[MAXN]; char ch[10];
156 
157     inline void solve(void) {
158         fread(BUF, 1, 50000000, stdin);
159         null->same = null->m = null->v = -inf;
160         int kase = ScanInt();
161         while (kase--) {
162             int n = ScanInt(), m = ScanInt();
163             for (int i = 1; i <= n; i++) a[i] = ScanInt();
164             T.root = T.Build(0, n + 1, a);
165             for (int i = 1; i <= m; i++) {
166                 ScanStr(ch);
167                 if (strcmp(ch, "INSERT") == 0) {
168                     int pos = ScanInt(), t = ScanInt();
169                     for (int j = 1; j <= t; j++) a[j] = ScanInt();
170                     Node *tmp = T.Build(1, t, a);
171                     T.Insert(pos, tmp);
172                 }
173                 if (strcmp(ch, "DELETE") == 0) {
174                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;
175                     T.Delete(l, r);
176                 }
177                 if (strcmp(ch, "MAKE-SAME") == 0) {
178                     int l = ScanInt(), r = ScanInt(), c = ScanInt(); r = l + r - 1;
179                     T.Make(l, r, c);
180                 }
181                 if (strcmp(ch, "REVERSE") == 0) {
182                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;
183                     T.Reverse(l, r);
184                 }
185                 if (strcmp(ch, "GET-SUM") == 0) {
186                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;
187                     int ret = T.Sum(l, r);
188                     printf("%d\n", ret);
189                 }
190                 if (strcmp(ch, "MAX-SUM") == 0) {
191                     int ret = T.maxSum(1, T.root->size - 2);
192                     printf("%d\n", ret);
193                 }
194             }
195         }
196     }
197 }
198 
199 int main(void) {
200     freopen("in", "r", stdin);
201     Solve::solve();
202     return 0;
203 }
原文地址:https://www.cnblogs.com/yefeng1627/p/3006308.html