[HDOJ4578]Transformation(线段树,多延迟标记)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4578

四种操作:区间加法、区间乘法、区间赋值(改数)和查询。需要维护三个 lazy 标记(加法、乘法、赋值),剩下的就是常规套路了。查询要求区间内所有数的 p 次幂之和(对 10007 取模),p 只有 1、2、3 三种取值。所以只需在同一棵线段树的每个结点上同时维护一次方和、二次方和、三次方和(代码中的 sum1、sum2、sum3)。

注意下传延迟标记时,如果有赋值(改数)标记,之前的加法和乘法标记都可以直接作废。更新乘法时,如果结点上已有加法标记,加法标记也要乘上这个乘数。原因是:结点中每个数可以看作 x·mul + add,整体再乘以 c 得到 x·(mul·c) + add·c,即 (x + a)·c = x·c + a·c,所以待下传的加法量 a 也变成了 a·c。

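pushdown 的顺序是:赋值、乘法、加法。三种幂次和的更新也好处理,把 (x+a)^2 = x^2 + 2ax + a^2 和 (x+a)^3 = x^3 + 3ax^2 + 3a^2x + a^3 展开后对区间求和即可(len 为区间长度):

- 新 sum2 = sum2 + 2a·sum1 + len·a^2
- 新 sum3 = sum3 + 3a·sum2 + 3a^2·sum1 + len·a^3

更新顺序必须是先算 sum3、再算 sum2、最后算 sum1,因为每一步都要用旧值。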

这是我做过的比较复杂的线段树了。

  1 /* 
  2 ┓┏┓┏┓┃キリキリ♂ mind!
  3 ┛┗┛┗┛┃\○/
  4 ┓┏┓┏┓┃ /
  5 ┛┗┛┗┛┃ノ)
  6 ┓┏┓┏┓┃
  7 ┛┗┛┗┛┃
  8 ┓┏┓┏┓┃
  9 ┛┗┛┗┛┃
 10 ┓┏┓┏┓┃
 11 ┛┗┛┗┛┃
 12 ┓┏┓┏┓┃
 13 ┃┃┃┃┃┃
 14 ┻┻┻┻┻┻
 15 */
 16 #include <algorithm>
 17 #include <iostream>
 18 #include <iomanip>
 19 #include <cstring>
 20 #include <climits>
 21 #include <complex>
 22 #include <fstream>
 23 #include <cassert>
 24 #include <cstdio>
 25 #include <bitset>
 26 #include <vector>
 27 #include <deque>
 28 #include <queue>
 29 #include <stack>
 30 #include <ctime>
 31 #include <set>
 32 #include <map>
 33 #include <cmath>
 34 using namespace std;
 35 #define fr first
 36 #define sc second
 37 #define cl clear
 38 #define BUG puts("here!!!")
 39 #define W(a) while(a--)
 40 #define pb(a) push_back(a)
 41 #define Rlf(a) scanf("%lf", &a);
 42 #define Rint(a) scanf("%d", &a)
 43 #define Rll(a) scanf("%I64d", &a)
 44 #define Rs(a) scanf("%s", a)
 45 #define Cin(a) cin >> a
 46 #define FRead() freopen("in", "r", stdin)
 47 #define FWrite() freopen("out", "w", stdout)
 48 #define Rep(i, len) for(int i = 0; i < (len); i++)
 49 #define For(i, a, len) for(int i = (a); i < (len); i++)
 50 #define Cls(a) memset((a), 0, sizeof(a))
 51 #define Clr(a, x) memset((a), (x), sizeof(a))
 52 #define Full(a) memset((a), 0x7f7f, sizeof(a))
 53 #define lrt rt << 1
 54 #define rrt rt << 1 | 1
 55 #define pi 3.14159265359
 56 #define RT return
 57 #define lowbit(x) x & (-x)
 58 #define onenum(x) __builtin_popcount(x)
 59 typedef long long LL;
 60 typedef long double LD;
 61 typedef unsigned long long ULL;
 62 typedef pair<int, int> pii;
 63 typedef pair<string, int> psi;
 64 typedef map<string, int> msi;
 65 typedef vector<int> vi;
 66 typedef vector<LL> vl;
 67 typedef vector<vl> vvl;
 68 typedef vector<bool> vb;
 69 
 70 inline bool scan_d(int &num) {
 71     char in;bool IsN=false;
 72     in=getchar();
 73     if(in==EOF) return false;
 74     while(in!='-'&&(in<'0'||in>'9')) in=getchar();
 75     if(in=='-'){ IsN=true;num=0;}
 76     else num=in-'0';
 77     while(in=getchar(),in>='0'&&in<='9'){
 78             num*=10,num+=in-'0';
 79     }
 80     if(IsN) num=-num;
 81     return true;
 82 }
 83 
 84 const int mod = 10007;
 85 const int maxn = 100010;
 86 
 87 LL add[maxn<<2], put[maxn<<2], mul[maxn<<2];
 88 LL sum1[maxn<<2], sum2[maxn<<2], sum3[maxn<<2];
 89 
 90 void pushUP(int rt) {
 91     sum1[rt] = (sum1[lrt] + sum1[rrt]) % mod;
 92     sum2[rt] = (sum2[lrt] + sum2[rrt]) % mod;
 93     sum3[rt] = (sum3[lrt] + sum3[rrt]) % mod;
 94 }
 95 
 96 void pushDOWN(int rt, int m) {
 97     if(put[rt]) {
 98         put[lrt] = put[rrt] = put[rt];
 99         add[lrt] = add[rrt] = 0;
100         mul[lrt] = mul[rrt] = 1;
101         sum1[lrt] = (m - (m >> 1)) % mod * put[rt] % mod;
102         sum1[rrt] = (m >> 1) % mod * put[rt] % mod;
103         sum2[lrt] = (m - (m >> 1)) % mod * put[rt] % mod * put[rt] % mod;
104         sum2[rrt] = (m >> 1) % mod * put[rt] % mod * put[rt] % mod;
105         sum3[lrt] = (m - (m >> 1) % mod) * (put[rt] * put[rt]) % mod * put[rt] % mod % mod;
106         sum3[rrt] = (m >> 1) % mod * put[rt] * put[rt] % mod * put[rt] % mod % mod;
107         put[rt] = 0;
108     }
109     if(mul[rt] != 1) {
110         mul[lrt] = mul[lrt] * mul[rt] % mod;
111         mul[rrt] = mul[rrt] * mul[rt] % mod;
112         if(add[lrt]) add[lrt] = (add[lrt] * mul[rt]) % mod;
113         if(add[rrt]) add[rrt] = (add[rrt] * mul[rt]) % mod;
114         sum1[lrt] = (sum1[lrt] * mul[rt]) % mod;
115         sum1[rrt] = (sum1[rrt] * mul[rt]) % mod;
116         sum2[lrt] = (sum2[lrt] * mul[rt]) % mod * mul[rt] % mod;
117         sum2[rrt] = (sum2[rrt] * mul[rt]) % mod * mul[rt] % mod;
118         sum3[lrt] = (sum3[lrt] * mul[rt]) % mod * mul[rt] % mod * mul[rt] % mod;
119         sum3[rrt] = (sum3[rrt] * mul[rt]) % mod * mul[rt] % mod * mul[rt] % mod;
120         mul[rt] = 1;
121     }
122     if(add[rt]) {
123         add[lrt] = (add[lrt] + add[rt]) % mod;
124         add[rrt] = (add[rrt] + add[rt]) % mod;
125         sum3[lrt] = (sum3[lrt] + ((add[rt] * add[rt] % mod) * add[rt] % mod * (m - (m >> 1)) % mod) + 3 * add[rt] * ((sum2[lrt] + sum1[lrt] * add[rt]) % mod)) % mod;
126         sum3[rrt] = (sum3[rrt] + ((add[rt] * add[rt] % mod) * add[rt] % mod * (m >> 1) % mod) + 3 * add[rt] * ((sum2[rrt] + sum1[rrt] * add[rt]) % mod)) % mod;
127         sum2[lrt] = (sum2[lrt] + ((add[rt] * add[rt] % mod) * (m - (m >> 1)) % mod) + (2 * sum1[lrt] * add[rt] % mod)) % mod;
128         sum2[rrt] = (sum2[rrt] + (((add[rt] * add[rt] % mod) * (m >> 1)) % mod) + (2 * sum1[rrt] * add[rt] % mod)) % mod;
129         sum1[lrt] = (sum1[lrt] + (m - (m >> 1)) * add[rt]) % mod;
130         sum1[rrt] = (sum1[rrt] + (m >> 1) * add[rt]) % mod;
131         add[rt] = 0;
132     }
133 }
134 
135 void build(int l, int r, int rt) {
136     add[rt] = put[rt] = 0; mul[rt] = 1;
137     sum1[rt] =sum2[rt] = sum3[rt] = 0;
138     if(l == r) return;
139     int m = (l + r) >> 1;
140     build(l, m, lrt);
141     build(m+1, r, rrt);
142 }
143 
144 void update(int L, int R, int c, int ch, int l, int r, int rt) {
145     if(L <= l && r <= R) {
146         if(ch == 3) {
147             put[rt] = c;
148             add[rt] = 0;
149             mul[rt] = 1;
150             sum1[rt] = ((r - l + 1) * c) % mod;
151             sum2[rt] = (((r - l + 1) * c) % mod * c) % mod;
152             sum3[rt] = ((((r - l + 1) * c) % mod * c) % mod * c) % mod;
153         }
154         if(ch == 2) {
155             mul[rt] = (mul[rt] * c) % mod;
156             if(add[rt]) add[rt] = (add[rt] * c) % mod;
157             sum1[rt] = (sum1[rt] * c) % mod;
158             sum2[rt] = ((sum2[rt] * c) % mod * c) % mod;
159             sum3[rt] = (((sum3[rt] * c) % mod * c) % mod * c) % mod;
160         }
161         if(ch == 1) {
162             add[rt] += c;
163             sum3[rt] = (sum3[rt] + (((c * c) % mod * c) % mod * (r - l + 1)) % mod + 3 * c * ((sum2[rt] + sum1[rt] * c) % mod)) % mod;
164             sum2[rt] = (sum2[rt] + (c * c % mod * (r - l + 1) % mod) + 2 * sum1[rt] * c) % mod;
165             sum1[rt] = (sum1[rt] + (r - l + 1) * c) % mod;
166         }
167         return;
168     }
169     pushDOWN(rt, r-l+1);
170     int m = (l + r) >> 1;
171     if(R <= m) update(L, R, c, ch, l, m, lrt);
172     else if(L > m) update(L, R, c, ch, m+1, r, rrt);
173     else {
174         update(L, R, c, ch, l, m, lrt);
175         update(L, R, c, ch, m+1, r, rrt);
176     }
177     pushUP(rt);
178 }
179 
180 LL query(int L, int R, int p, int l, int r, int rt) {
181     if(L <= l && r <= R) {
182         if(p == 1) return sum1[rt] % mod;
183         if(p == 2) return sum2[rt] % mod;
184         if(p == 3) return sum3[rt] % mod;
185     }
186     pushDOWN(rt, r-l+1);
187     int m = (l + r) >> 1;
188     if(R <= m) return query(L, R, p, l, m, lrt);
189     else if(m < L) return query(L, R, p, m+1, r, rrt);
190     else return (query(L, R, p, l, m, lrt) + query(L, R, p, m+1, r, rrt)) % mod;
191 }
192 
193 
194 int n, m;
195 int a, b, c, ch;
196 
197 int main() {
198     // FRead();
199     while(~scan_d(n) && ~scan_d(m) && n + m) {
200         build(1, n, 1);
201         W(m) {
202             scan_d(ch); scan_d(a); scan_d(b); scan_d(c);
203             if(ch != 4) update(a, b, c, ch, 1, n, 1);
204             else cout << query(a, b, c, 1, n, 1) << endl;
205         }
206     }
207     RT 0;
208 }
原文地址:https://www.cnblogs.com/kirai/p/5558720.html