Hdu 4578 Transformation (线段树 分类分析)

Transformation

Time Limit: 15000/8000 MS (Java/Others)    Memory Limit: 65535/65536 K (Java/Others)
Total Submission(s): 10082    Accepted Submission(s): 2609


Problem Description
Yuanfang is puzzled with the question below: 
There are $n$ integers $a_1, a_2, \dots, a_n$, all initially 0. There are four kinds of operations.
Operation 1: Add $c$ to each number from $a_x$ to $a_y$ inclusive. In other words, apply $a_k \leftarrow a_k + c$ for $k = x, x+1, \dots, y$.
Operation 2: Multiply each number from $a_x$ to $a_y$ inclusive by $c$. In other words, apply $a_k \leftarrow a_k \times c$ for $k = x, x+1, \dots, y$.
Operation 3: Set each number from $a_x$ to $a_y$ inclusive to $c$. In other words, apply $a_k \leftarrow c$ for $k = x, x+1, \dots, y$.
Operation 4: Compute the sum of the $p$-th powers of the numbers from $a_x$ to $a_y$ inclusive, i.e. $a_x^p + a_{x+1}^p + \dots + a_y^p$.
Yuanfang has no idea of how to do it. So he wants to ask you to help him. 
 
Input
There are no more than 10 test cases.
For each case, the first line contains two numbers n and m, meaning that there are n integers and m operations. 1 <= n, m <= 100,000.
Each the following m lines contains an operation. Operation 1 to 3 is in this format: "1 x y c" or "2 x y c" or "3 x y c". Operation 4 is in this format: "4 x y p". (1 <= x <= y <= n, 1 <= c <= 10,000, 1 <= p <= 3)
The input ends with 0 0.
 
Output
For each operation 4, output a single integer in one line representing the result. The answer may be quite large. You just need to calculate the remainder of the answer when divided by 10007.
 
Sample Input
5 5
3 3 5 7
1 2 4 4
4 1 5 2
2 2 5 8
4 3 5 3
0 0
 
Sample Output
307
7489
 
Source
 

题解

这道题有三种修改操作:set、add、mul,所以lazy标记要有三个。三个标记同时出现时的处理方法是:更新set操作时,把add标记和mul标记全部取消(add置0,mul置1);更新mul操作时,如果当前节点的add标记存在,就把add标记改为add * mul。这样就可以在PushDown()中按先set、再mul、最后add的顺序下放标记(代码二采用这种做法;代码一则用eq记录整段相同的值)。

C++代码一

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int maxn = 100010 ;
// Child indices of heap-numbered segment-tree node `v`. Both macros refer to
// a variable named `v` that must be in scope wherever they are used.
// Parenthesized so that an expression such as `left + 1` parses as intended
// (unparenthesized, `v<<1+1` would mean `v<<2`).
// NOTE: these macros hide std::left / std::right for the rest of the file.
#define left (v<<1)
#define right (v<<1|1)
// All answers are reported modulo 10007.
const int mod = 10007;
// One segment-tree node.
struct node
{
    int l ,r , value ;   // segment bounds [l, r]; `value` is not used by this program
    // eq  : -1 if the segment is not known to be uniform, otherwise the common
    //       value (mod 10007) shared by every element of [l, r]
    // add : pending addend still to be pushed to the children (0 = none)
    // mul : pending multiplier still to be pushed to the children (1 = none)
    int eq , add , mul ;
}tree[maxn<<2];
void build(int l , int r  , int v)
{
    tree[v].l = l ;
    tree[v].r = r ;
    tree[v].add = 0 ; tree[v].mul = 1 ;tree[v].eq = -1 ;
    if(l == r)
    {tree[v].eq = 0 ; return  ;}
    int mid = (l + r) >> 1 ;
    build(l , mid , left) ;
    build(mid + 1 , r , right) ;
}
// Push node v's pending state down to its two children.
// - If v is uniform (eq != -1), both children become uniform with the same
//   value and lose their own add/mul tags, and v stops being uniform.
// - Otherwise a pending mul (then a pending add) is forwarded. A uniform
//   child folds it straight into its eq. A non-uniform child first flushes its
//   own tags with a recursive push_down, then records the new tag, so its
//   earlier tags are never combined with the new one.
// The recursion can cascade through several levels of the tree.
void push_down(int v)
{
    if(tree[v].l == tree[v].r)return ;   // leaves have no children
    if(tree[v].eq != -1)
    {
        // Uniform segment: copy the value into both children and clear their tags.
        tree[left].eq = tree[right].eq = tree[v].eq ;
        tree[left].add = tree[right].add = 0 ;
        tree[left].mul = tree[right].mul = 1;
        // v is no longer treated as uniform; callers go on to descend into the children.
        tree[v].eq = -1;
        return  ;
    }
    if(tree[v].mul != 1)
    {
        // Forward the pending multiplier to the left child.
        if(tree[left].eq != -1)
        tree[left].eq = (tree[left].eq*tree[v].mul)%mod ;
        else
        {
            push_down(left) ;   // flush the child's own tags first
            tree[left].mul = (tree[left].mul*tree[v].mul)%mod ;
        }
        // Same for the right child.
        if(tree[right].eq != -1)
        tree[right].eq = (tree[right].eq*tree[v].mul)%mod ;
        else
        {
            push_down(right) ;
            tree[right].mul = (tree[right].mul*tree[v].mul)%mod ;
        }
        tree[v].mul = 1;
    }
    if(tree[v].add)
    {
        // Forward the pending addend to the left child.
        if(tree[left].eq != -1)
        tree[left].eq = (tree[left].eq + tree[v].add)%mod ;
        else
        {
            push_down(left) ;   // flush the child's own tags first
            tree[left].add = (tree[left].add + tree[v].add)%mod ;
        }
        // Same for the right child.
        if(tree[right].eq != -1)
        tree[right].eq = (tree[right].eq + tree[v].add)%mod ;
        else
        {
            push_down(right) ;
            tree[right].add = (tree[right].add + tree[v].add)%mod ;
        }
        tree[v].add = 0 ;
    }
}
void update(int l , int r , int v , int op , int c)
{
    if(l <= tree[v].l && tree[v].r <= r)
    {
        if(op == 3)
        {
            tree[v].add = 0 ;tree[v].mul = 1;
            tree[v].eq  = c ;
            return ;
        }
        if(tree[v].eq != -1)
        {
            if(op == 1)tree[v].eq = (tree[v].eq + c)%mod ;
            else tree[v].eq = (tree[v].eq*c)%mod ;
        }
        else
        {
            push_down(v) ;
            if(op == 1)tree[v].add = (tree[v].add + c)%mod ;
            else tree[v].mul = (tree[v].mul*c)%mod ;
        }
        return ;
    }
    push_down(v) ;
    int mid = (tree[v].l + tree[v].r) >> 1 ;
    if(l <= mid)update(l , r ,left , op , c) ;
    if(r > mid)update(l , r , right , op , c) ;
}
int query(int l , int r , int v , int q)
{
    if(tree[v].l >= l && tree[v].r <= r && tree[v].eq != -1)
    {
        int ans = 1;
        for(int i = 1;i <= q;i++)
        ans = (ans * tree[v].eq)%mod ;
        return (ans*((tree[v].r - tree[v].l + 1)%mod))%mod ;
    }
    push_down(v) ;
    int mid = (tree[v].l  + tree[v].r) >> 1 ;
    if(l > mid)return query(l , r , right, q) ;
    else if(r <= mid)return query(l , r ,left ,q) ;
    else return (query(l , mid , left , q) + query(mid + 1 , r , right , q))%mod ;
}
int main()
{
    //freopen("in.txt" ,"r" , stdin) ;
    int n , m ;
    while(scanf("%d%d" , &n , &m) &&(n+m))
    {
        int op , x ,  y , c;
        build(1 , n , 1) ;
        while(m--)
        {
            scanf("%d%d%d%d" , &op , &x , &y , &c) ;
            if(op == 4)
            printf("%d
" , (query(x, y , 1 , c)%mod)) ;
            else update(x , y , 1 , op , c) ;
        }
    }
    return  0 ;
}

C++代码二

解释

  平方和这样来推:$(a + c)^2 = a^2 + c^2 + 2ac$,即 sum2[rt] = sum2[rt] + (r - l + 1) * c * c + 2 * sum1[rt] * c;

  立方和这样推:$(a + c)^3 = a^3 + 3a^2c + 3ac^2 + c^3 = a^3 + c^3 + 3c(a^2 + ac)$,即 sum3[rt] = sum3[rt] + (r - l + 1) * c * c * c + 3 * c * (sum2[rt] + sum1[rt] * c);

  几个注意点:add标记取消时置0,mul标记取消时置1;在PushDown()中也要注意取消标记,如set操作中取消add和mul,mul操作中更新add;在add操作中要注意sum3、sum2、sum1的更新顺序,一定是先sum3,然后sum2,最后sum1(因为高次和的更新要用到旧的低次和);int容易溢出,用long long(代码中的ll)更保险;最后就是运算较多,不要漏掉东西。

当然这种方法有取巧的成分

#include <iostream>
#include <cstdio>
#include <vector>
#include <cmath>
#include <string>
#include <string.h>
#include <algorithm>
using namespace std;
// NOTE(review): LL expands to the MSVC-specific __int64 and is not used below;
// the typedef ll is the 64-bit type actually used.
#define LL __int64
typedef long long ll;
// eps and INF are not used in this program (INT_MAX would also need <climits>).
#define eps 1e-8
#define INF INT_MAX
// lson / rson expand to the (l, r, rt) argument triple of the left / right
// child; they require locals named l, r, m and rt at the point of use.
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
const int MOD = 10007; 
const int maxn = 100000 + 5;
// N is not used in this program.
const int N = 12;
// Lazy tags per node:
//   add - pending addend for the children (0 = none)
//   set - pending assignment for the children (0 = none; relies on c >= 1)
//   mul - pending multiplier for the children (1 = none)
// NOTE(review): a global named `set` coexists with `using namespace std;`.
// If <set> is ever included, unqualified `set` becomes ambiguous.
ll add[maxn << 2] , set[maxn << 2] , mul[maxn << 2];
// Per-node sums of a_i, a_i^2 and a_i^3 over the node's segment, mod MOD.
ll sum1[maxn << 2] , sum2[maxn << 2] , sum3[maxn << 2];
// Recompute node rt's three power sums from its two children.
void PushUp(int rt)
{
    const int lc = rt << 1;
    const int rc = lc | 1;
    sum1[rt] = (sum1[lc] + sum1[rc]) % MOD;
    sum2[rt] = (sum2[lc] + sum2[rc]) % MOD;
    sum3[rt] = (sum3[lc] + sum3[rc]) % MOD;
}
// Reset the subtree rooted at rt covering [l, r]: no pending tags and all
// power sums 0.
void build(int l , int r , int rt)
{
    mul[rt] = 1;
    set[rt] = 0;
    add[rt] = 0;
    if (l != r) {
        int m = (l + r) >> 1;
        build(lson);
        build(rson);
        PushUp(rt);
        return;
    }
    sum1[rt] = sum2[rt] = sum3[rt] = 0;
}
// Push node rt's pending tags to its two children, in the order set, mul, add.
// len is the length of rt's segment (callers pass r - l + 1). The left child
// covers len - (len >> 1) elements and the right child covers len >> 1.
// Each step updates the children's tags and their sum1/sum2/sum3, then clears
// rt's tag.
void PushDown(int rt , int len)
{
    if(set[rt]) {
        // Assignment: children take the value and drop their add/mul tags.
        set[rt << 1] = set[rt << 1 | 1] = set[rt];
        add[rt << 1] = add[rt << 1 | 1] = 0;    // note: these tags must be reset on the children as well
        mul[rt << 1] = mul[rt << 1 | 1] = 1;
        ll tmp = ((set[rt] * set[rt]) % MOD) * set[rt] % MOD;   // set^3 mod MOD
        sum1[rt << 1] = ((len - (len >> 1)) % MOD) * (set[rt] % MOD) % MOD;
        sum1[rt << 1 | 1] = ((len >> 1) % MOD) * (set[rt] % MOD) % MOD;
        sum2[rt << 1] = ((len - (len >> 1)) % MOD) * ((set[rt] * set[rt]) % MOD) % MOD;
        sum2[rt << 1 | 1] = ((len >> 1) % MOD) * ((set[rt] * set[rt]) % MOD) % MOD;
        sum3[rt << 1] = ((len - (len >> 1)) % MOD) * tmp % MOD;
        sum3[rt << 1 | 1] = ((len >> 1) % MOD) * tmp % MOD;
        set[rt] = 0;
    }
    if(mul[rt] != 1) {    // the test must be mul[rt] != 1 (the author originally missed this and got TLE)
        mul[rt << 1] = (mul[rt << 1] * mul[rt]) % MOD;
        mul[rt << 1 | 1] = (mul[rt << 1 | 1] * mul[rt]) % MOD;
        if(add[rt << 1])    // note: a pending add on the child must be scaled too
            add[rt << 1] = (add[rt << 1] * mul[rt]) % MOD;
        if(add[rt << 1 | 1])
            add[rt << 1 | 1] = (add[rt << 1 | 1] * mul[rt]) % MOD;
        ll tmp = (((mul[rt] * mul[rt]) % MOD * mul[rt]) % MOD);   // mul^3 mod MOD
        // sum_k scales by mul^k.
        sum1[rt << 1] = (sum1[rt << 1] * mul[rt]) % MOD;
        sum1[rt << 1 | 1] = (sum1[rt << 1 | 1] * mul[rt]) % MOD;
        sum2[rt << 1] = (sum2[rt << 1] % MOD) * ((mul[rt] * mul[rt]) % MOD) % MOD;
        sum2[rt << 1 | 1] = (sum2[rt << 1 | 1] % MOD) * ((mul[rt] * mul[rt]) % MOD) % MOD;
        sum3[rt << 1] = (sum3[rt << 1] % MOD) * tmp % MOD;
        sum3[rt << 1 | 1] = (sum3[rt << 1 | 1] % MOD) * tmp % MOD;
        mul[rt] = 1;
    }
    if(add[rt]) {
        // NOTE(review): add[] accumulates with += and is not reduced mod MOD here.
        // The products below appear to stay within long long for the problem's
        // limits, but reducing it would be safer.
        add[rt << 1] += add[rt];    // add tags combine with +=, mul tags with *=
        add[rt << 1 | 1] += add[rt];
        ll tmp = (add[rt] * add[rt] % MOD) * add[rt] % MOD;        // order matters: update sum3, then sum2, then sum1 (each uses the old lower-power sums)
        sum3[rt << 1] = (sum3[rt << 1] + (tmp * (len - (len >> 1)) % MOD) + 3 * add[rt] * ((sum2[rt << 1] + sum1[rt << 1] * add[rt]) % MOD)) % MOD;
        sum3[rt << 1 | 1] = (sum3[rt << 1 | 1] + (tmp * (len >> 1) % MOD) + 3 * add[rt] * ((sum2[rt << 1 | 1] + sum1[rt << 1 | 1] * add[rt]) % MOD)) % MOD;
        sum2[rt << 1] = (sum2[rt << 1] + ((add[rt] * add[rt] % MOD) * (len - (len >> 1)) % MOD) + (2 * sum1[rt << 1] * add[rt] % MOD)) % MOD;
        sum2[rt << 1 | 1] = (sum2[rt << 1 | 1] + (((add[rt] * add[rt] % MOD) * (len >> 1)) % MOD) + (2 * sum1[rt << 1 | 1] * add[rt] % MOD)) % MOD;
        sum1[rt << 1] = (sum1[rt << 1] + (len - (len >> 1)) * add[rt]) % MOD;
        sum1[rt << 1 | 1] = (sum1[rt << 1 | 1] + (len >> 1) * add[rt]) % MOD;
        add[rt] = 0;
    }
}
void update(int L , int R , int c , int ch , int l , int r , int rt)
{
    if(L <= l && R >= r) {
        if(ch == 3) {
            set[rt] = c;
            add[rt] = 0;
            mul[rt] = 1;
            sum1[rt] = ((r - l + 1) * c) % MOD;
            sum2[rt] = ((r - l + 1) * ((c * c) % MOD)) % MOD;
            sum3[rt] = ((r - l + 1) * (((c * c) % MOD) * c % MOD)) % MOD;
        } else if(ch == 2) {
            mul[rt] = (mul[rt] * c) % MOD;
            if(add[rt])
                add[rt] = (add[rt] * c) % MOD;
            sum1[rt] = (sum1[rt] * c) % MOD;
            sum2[rt] = (sum2[rt] * (c * c % MOD)) % MOD;
            sum3[rt] = (sum3[rt] * ((c * c % MOD) * c % MOD)) % MOD;
        } else if(ch == 1) {
            add[rt] += c;
            ll tmp = (((c * c) % MOD * c) % MOD * (r - l + 1)) % MOD;    //(r - l + 1) * c^3
            sum3[rt] = (sum3[rt] + tmp + 3 * c * ((sum2[rt] + sum1[rt] * c) % MOD)) % MOD;
            sum2[rt] = (sum2[rt] + (c * c % MOD * (r - l + 1) % MOD) + 2 * sum1[rt] * c) % MOD;
            sum1[rt] = (sum1[rt] + (r - l + 1) * c) % MOD;
        }
        return;
    }
    PushDown(rt , r - l + 1);
    int m = (l + r) >> 1;
    if(L > m)
        update(L , R , c , ch , rson);
    else if(R <= m)
        update(L , R , c , ch , lson);
    else {
        update(L , R , c , ch , lson);
        update(L , R , c , ch , rson);
    }
    PushUp(rt);
}
// Sum of p-th powers (p = 1, 2, or anything else meaning 3) over [L, R]
// inside node rt covering [l, r], modulo MOD.
ll query(int L , int R , int p , int l , int r , int rt)
{
    if (L <= l && R >= r) {
        const ll *table = (p == 1) ? sum1 : (p == 2) ? sum2 : sum3;
        return table[rt] % MOD;
    }
    PushDown(rt , r - l + 1);
    int m = (l + r) >> 1;
    if (L > m)
        return query(L , R , p , rson);
    if (R <= m)
        return query(L , R , p , lson);
    const ll lhs = query(L , R , p , lson);
    const ll rhs = query(L , R , p , rson);
    return (lhs + rhs) % MOD;
}
int main()
{
    int n , m;
    int a , b , c , ch;
    while(~scanf("%d %d" , &n , &m))
    {
        if(n == 0 && m == 0)
            break;
        build(1 , n , 1);
        while(m--) {
            scanf("%d %d %d %d" , &ch , &a , &b , &c);
            if(ch != 4) {
                update(a , b , c , ch , 1 , n , 1);
            } else {
                printf("%lld
" , query(a , b , c , 1 , n , 1));
            }
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/DWVictor/p/11203932.html