HDU 4578 线段树各种区间操作

HDU 4578

原题链接

【题意】:初始一个长度为n的数组全为0,有m个操作,输入op, l, r, x。

  1. op = 1时,把 [l, r] 中的所有数加上x
  2. op = 2时, 把 [l, r] 中的所有树乘上x
  3. op = 3时, 把[l, r]中的所有数全置为x
  4. op = 4时, 输出 [l, r] 中所有数的 x 方的和

【思路】:令(x = a * x + b),即对一个x,令他乘以a, 再加上b,观察各个和的情况:

[sum1=∑x⇒∑(a×x+b)=a*sum1+b*length ]

[sum2=∑x^2⇒∑(a×x+b)^2=a^2 *sum2+2*a*b*sum1+b^2*ength ]

[sum3=∑x^3⇒∑(a×x+b)^3=a^3*sum3+3*a^2*b*sum2+3*a*b^2*sum1+b^3×length ]

只执行加法时:a = 1, b = x; 只执行乘法时, a = x, b = 0; 一起执行时,(a = t[cur].mul, b = t[cur].add).

在置数时直接将add和mul懒标记恢复到初始值即可。

#include <bits/stdc++.h>
#define debug(x) cout << #x << " = " << x << endl; 
#define ls cur<<1
#define rs cur<<1|1

using namespace std;
typedef long long LL;
const int maxn = 2e5 + 10;
const int inf = 0x3f3f3f3f;
const int mod = 10007;
const double eps = 1e-6, pi = acos(-1.0);

struct Tree{
    int l, r, len;
    LL sum1, sum2, sum3;
    LL add, mul, st;
}t[maxn<<1];
int n, m;

inline void pushup(int cur){
    t[cur].sum1 = (t[ls].sum1 + t[rs].sum1) % mod;
    t[cur].sum2 = (t[ls].sum2 + t[rs].sum2) % mod;
    t[cur].sum3 = (t[ls].sum3 + t[rs].sum3) % mod;
}
inline void pushdown(int cur){
    if(t[cur].st){		//只有置数
        t[ls].mul = t[rs].mul = 1;
        t[ls].add = t[rs].add = 0;
        t[ls].st = t[rs].st = t[cur].st;
        
        t[ls].sum3 = t[cur].st * t[cur].st % mod * t[cur].st % mod * t[ls].len % mod;
        t[rs].sum3 = t[cur].st * t[cur].st % mod * t[cur].st % mod * t[rs].len % mod;

        t[ls].sum2 = t[cur].st * t[cur].st % mod * t[ls].len % mod;
        t[rs].sum2 = t[cur].st * t[cur].st % mod * t[rs].len % mod;

        t[ls].sum1 = t[cur].st * t[ls].len % mod;
        t[rs].sum1 = t[cur].st * t[rs].len % mod;

        t[cur].st = 0;
    }
    if(t[cur].add != 0 || t[cur].mul != 1){	//加和乘都有
        LL a = t[cur].mul, b = t[cur].add;

        t[ls].mul = t[cur].mul * t[ls].mul % mod;
        t[rs].mul = t[cur].mul * t[rs].mul % mod;

        t[ls].add = (t[ls].add * t[cur].mul % mod + t[cur].add) % mod;
        t[rs].add = (t[rs].add * t[cur].mul % mod + t[cur].add) % mod;

        t[ls].sum3 = (a * a % mod * a % mod * t[ls].sum3 % mod + 3 * a * a % mod * b % mod * t[ls].sum2 % mod + 3 * a * b % mod * b % mod * t[ls].sum1 % mod + b * b % mod * b % mod * t[ls].len % mod) % mod;
        t[rs].sum3 = (a * a % mod * a % mod * t[rs].sum3 % mod + 3 * a * a % mod * b % mod * t[rs].sum2 % mod + 3 * a * b % mod * b % mod * t[rs].sum1 % mod + b * b % mod * b % mod * t[rs].len % mod) % mod;

        t[ls].sum2 = (a * a % mod * t[ls].sum2 % mod + 2 * a * b % mod * t[ls].sum1 + b * b % mod * t[ls].len % mod) % mod;
        t[rs].sum2 = (a * a % mod * t[rs].sum2 % mod + 2 * a * b % mod * t[rs].sum1 + b * b % mod * t[rs].len % mod) % mod;

        t[ls].sum1 = (a * t[ls].sum1 % mod + b * t[ls].len % mod) % mod;
        t[rs].sum1 = (a * t[rs].sum1 % mod + b * t[rs].len % mod) % mod;

        t[cur].add = 0; t[cur].mul = 1;
    }
}
void build(int l, int r, int cur){
    t[cur].l = l; t[cur].r = r; t[cur].len = r - l + 1;
    t[cur].sum1 = t[cur].sum2 = t[cur].sum3 = 0;
    t[cur].add = t[cur].st = 0; t[cur].mul = 1;
    if(l == r)  return;
    int mid = l + r >> 1;
    build(l, mid, ls);
    build(mid + 1, r, rs);
    pushup(cur);
}
void change(int l, int r, int op, int x, int cur){
    if(l <= t[cur].l && t[cur].r <= r){
        LL tmp = x * x % mod * x % mod;
        if(op == 1){	//只加
            t[cur].add = (t[cur].add + x) % mod;
            t[cur].sum3 = (t[cur].sum3 + 3 * x * t[cur].sum2 % mod + 3 * x * x * t[cur].sum1 + tmp * t[cur].len % mod) % mod;
            t[cur].sum2 = (t[cur].sum2 + 2 * x * t[cur].sum1 % mod + x * x % mod * t[cur].len);
            t[cur].sum1 = (t[cur].sum1 + x * t[cur].len) % mod;
        }	
        else if(op == 2){	//只乘
            t[cur].add = (t[cur].add * x) % mod;
            t[cur].mul = (t[cur].mul * x) % mod;
            t[cur].sum3 = tmp * t[cur].sum3 % mod;
            t[cur].sum2 = x * x % mod * t[cur].sum2 % mod;
            t[cur].sum1 = x * t[cur].sum1 % mod;
        }	
        else if(op == 3){	//只置数
            t[cur].mul = 1;
            t[cur].add = 0;
            t[cur].st = x;
            t[cur].sum3 = tmp * t[cur].len % mod;
            t[cur].sum2 = x * x % mod * t[cur].len % mod;
            t[cur].sum1 = x * t[cur].len % mod;
        }
        return;
    }
    int mid = t[cur].l + t[cur].r >> 1;
    pushdown(cur);
    if(l <= mid)    change(l, r, op, x, ls);
    if(mid < r) change(l, r, op, x, rs);
    pushup(cur);
}
LL query(int l, int r, int x, int cur){
    if(l <= t[cur].l && t[cur].r <= r){
        if(x == 1) return t[cur].sum1;
        if(x == 2) return t[cur].sum2;
        if(x == 3) return t[cur].sum3;
    }
    int mid = t[cur].l + t[cur].r >> 1;
    pushdown(cur);
    LL ans = 0;
    if(l <= mid)    ans = (ans + query(l, r, x, ls)) % mod;
    if(mid < r) ans = (ans + query(l, r, x, rs)) % mod;
    return ans;
}

int main()
{
    while(~scanf("%d %d", &n, &m), n + m){
        build(1, n, 1);
        while(m--){
            int l, r, op, x;
            scanf("%d %d %d %d", &op, &l, &r, &x);
            if(op != 4) change(l, r, op, x % mod, 1);
            else printf("%lld
", query(l, r, x, 1));
        }
    }
    getchar(); getchar();
}
原文地址:https://www.cnblogs.com/StungYep/p/12828203.html