记一道乘法&加法线段树(模版题)

P3373 【模板】线段树 2

做法:

两个标记,一个标记是乘,一个是加,每次做乘法时,将前面的加法乘上当前的乘数,然后转移就可以先乘后加

(a*b+c)*d=a*b*d+c*d

如上式,a是原数, 操作顺序是乘b,加c,乘d。

两个标记的变换为 [1,0] => [b,0] => [b,c] => [b*d,c*d]

#include<bits/stdc++.h>
using namespace std;
#define ll long long
int p;
const int MAXN=1e5+5;
ll a[MAXN];
ll seg[MAXN<<2][3];
/*
0:区间和
1:乘
2:加
*/
void build(int i,int l,int r){
    seg[i][1]=1;
    if(l==r){
        seg[i][0]=a[l]%p;
        return;
    }
    int mid=l+(r-l)/2;
    build(i*2,l,mid);
    build(i*2+1,mid+1,r);
    seg[i][0]=seg[i*2][0]+seg[i*2+1][0];
    seg[i][0]%=p;
}

void push_down(int i,int l,int r){
    //先乘,后加
    seg[i*2][1]*=seg[i][1];seg[i*2][1]%=p;
    seg[i*2][2]*=seg[i][1];seg[i*2][2]%=p;
    seg[i*2+1][1]*=seg[i][1];seg[i*2+1][1]%=p;
    seg[i*2+1][2]*=seg[i][1];seg[i*2+1][2]%=p;
    //
    seg[i*2][2]+=seg[i][2];seg[i*2][2]%=p;
    seg[i*2+1][2]+=seg[i][2];seg[i*2+1][2]%=p;
    int mid=l+(r-l)/2;
    seg[i*2][0]=(seg[i*2][0]*seg[i][1]%p+seg[i][2]*(mid-l+1))%p;
    seg[i*2+1][0]=(seg[i*2+1][0]*seg[i][1]%p+seg[i][2]*(r-mid))%p;
    seg[i][1]=1;
    seg[i][2]=0;
}

void add(int i,int l,int r,int x,int y,int k){//加上k
    if(x<=l&&r<=y){
        seg[i][0]+=1ll*k*(r-l+1);seg[i][0]%=p;
        seg[i][2]+=k;
        return;
    }
    int mid=l+(r-l)/2;
    push_down(i,l,r);
    if(x<=mid) add(i*2,l,mid,x,y,k);
    if(y>mid) add(i*2+1,mid+1,r,x,y,k);
    seg[i][0]=seg[i*2][0]+seg[i*2+1][0];
    seg[i][0]%=p;
}

void mul(int i,int l,int r,int x,int y,int k){//乘k
    if(x<=l&&r<=y){
        seg[i][0]*=k;seg[i][0]%=p;
        seg[i][1]*=k;seg[i][1]%=p;
        seg[i][2]*=k;seg[i][2]%=p;
        return;
    }
    int mid=l+(r-l)/2;
    push_down(i,l,r);
    if(x<=mid) mul(i*2,l,mid,x,y,k);
    if(y>mid) mul(i*2+1,mid+1,r,x,y,k);
    seg[i][0]=seg[i*2][0]+seg[i*2+1][0];
    seg[i][0]%=p;
}

ll query(int i,int l,int r,int x,int y){
    if(r<x||l>y)
        return 0;
    if(x<=l&&r<=y)
        return seg[i][0];
    int mid=l+(r-l)/2;
    push_down(i,l,r);
    return (query(i*2,l,mid,x,y)+query(i*2+1,mid+1,r,x,y))%p;
}

int main(){
    int n,m;
    cin>>n>>m>>p;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    build(1,1,n);
    while(m--){
        int op,x,y,k;
        cin>>op>>x>>y;
        if(op==1){
            cin>>k;
            mul(1,1,n,x,y,k);
            continue;
        }
        if(op==2){
            cin>>k;
            add(1,1,n,x,y,k);
            continue;
        }
        cout<<query(1,1,n,x,y)<<endl;
    }

}
原文地址:https://www.cnblogs.com/xuanzo/p/15252179.html