线段树 B数据结构 牛客练习赛28

链接:https://ac.nowcoder.com/acm/contest/200/B
来源:牛客网

题目描述

qn姐姐最好了~
    qn姐姐给你了一个长度为n的序列还有m次操作让你玩,
    1 l r 询问区间[l,r]内的元素和
    2 l r 询问区间[l,r]内的元素的平方 
    3 l r x 将区间[l,r]内的每一个元素都乘上x
    4 l r x 将区间[l,r]内的每一个元素都加上x

输入描述:

第一行两个数n,m

接下来一行n个数表示初始序列

就下来m行每行第一个数为操作方法opt,

若opt=1或者opt=2,则之后跟着两个数为l,r

若opt=3或者opt=4,则之后跟着三个数为l,r,x

操作意思为题目描述里说的

输出描述:

对于每一个操作1,2,输出一行表示答案
示例1

输入

复制
5 6
1 2 3 4 5
1 1 5
2 1 5
3 1 2 1
4 1 3 2
1 1 4
2 2 3

输出

复制
15
55
16
41

备注:

对于100%的数据 n=10000,m=200000 (注意是等于号)

保证所有询问的答案在long long 范围内




这个比较简单,但是出现了一个很难找的bug。

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <algorithm>
#include <iostream>
#include <vector>
#include <map>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 1000;

ll a[maxn];

struct node
{
    int l, r;
    ll lazyc, lazyadd;
    ll sum, all;
}tree[maxn*4];

void push_up(int id)
{
    tree[id].sum = tree[id << 1].sum + tree[id << 1 | 1].sum;
    tree[id].all = tree[id << 1].all + tree[id << 1 | 1].all;
}

void push_down(int id)
{
    if(tree[id].lazyc>1)
    {
        tree[id << 1].sum = tree[id << 1].sum*tree[id].lazyc;
        tree[id << 1 | 1].sum = tree[id << 1 | 1].sum*tree[id].lazyc;
        tree[id << 1].all = tree[id << 1].all*tree[id].lazyc*tree[id].lazyc;
        tree[id << 1 | 1].all = tree[id << 1 | 1].all*tree[id].lazyc*tree[id].lazyc;
        tree[id << 1].lazyc *= tree[id].lazyc;
        tree[id << 1 | 1].lazyc *= tree[id].lazyc;
        tree[id].lazyc = 1;
    }
    if(tree[id].lazyadd)
    {
        tree[id << 1].all += tree[id << 1].sum * 2*tree[id].lazyadd + (tree[id << 1].r - tree[id << 1].l + 1)*tree[id].lazyadd*tree[id].lazyadd;
        tree[id << 1 | 1].all += tree[id << 1 | 1].sum * 2 *tree[id].lazyadd+ (tree[id << 1 | 1].r - tree[id << 1 | 1].l + 1)*tree[id].lazyadd*tree[id].lazyadd;
        tree[id << 1].sum += (tree[id<<1].r - tree[id<<1].l + 1)*tree[id].lazyadd;
        tree[id << 1 | 1].sum += (tree[id << 1 | 1].r - tree[id << 1 | 1].l + 1)*tree[id].lazyadd;
        tree[id << 1].lazyadd += tree[id].lazyadd;
        tree[id << 1 | 1].lazyadd += tree[id].lazyadd;
        tree[id].lazyadd = 0;
    }
}

void build(int id,int l,int r)
{
    tree[id].l = l;
    tree[id].r = r;
    tree[id].lazyc = 1;
    tree[id].lazyadd = 0;
    if(l==r)
    {
        tree[id].sum = a[l];
        tree[id].all = a[l] * a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(id << 1, l, mid);
    build(id << 1 | 1, mid + 1, r);
    push_up(id);
}

void updatec(int id,int l,int r,int x)
{
    push_down(id);
    if(l<=tree[id].l&&r>=tree[id].r)
    {
        tree[id].lazyc *= x;
        tree[id].lazyadd *= x;
        tree[id].sum = tree[id].sum*x;
        tree[id].all = tree[id].all*x*x;
        return;
    }
    int mid = (tree[id].l + tree[id].r) >> 1;
    if (l <= mid) updatec(id << 1, l, r, x);
    if (r > mid) updatec(id << 1 | 1, l, r, x);
    push_up(id);
}

void updateadd(int id,int l,int r,int x)
{
    push_down(id);
    if(l<=tree[id].l&&r>=tree[id].r)
    {
        tree[id].lazyadd += x;
        tree[id].all += tree[id].sum * 2*x + (tree[id].r - tree[id].l + 1)*x*x;
        tree[id].sum += (tree[id].r - tree[id].l + 1)*x;
        return;
    }
    int mid = (tree[id].l + tree[id].r) >> 1;
    if (l <= mid) updateadd(id << 1, l, r, x);
    if (r > mid) updateadd(id << 1 | 1, l, r, x);//这里的id<<1|1忘记+1了,就写成了id<<1
    push_up(id);
}


ll querysum(int id,int l,int r)
{
    
    if(l<=tree[id].l&&r>=tree[id].r)
    {
        return tree[id].sum;
    }
    ll ans = 0;
    push_down(id);
    int mid = (tree[id].l + tree[id].r) >> 1;
    if (l <= mid) ans += querysum(id << 1, l, r);
    if (r > mid) ans += querysum(id << 1 | 1, l, r);
    return ans;
}

ll queryc(int id,int l,int r)
{
    if(l<=tree[id].l&&r>=tree[id].r)
    {
        return tree[id].all;
    }
    push_down(id);
    ll ans = 0;
    int mid = (tree[id].l + tree[id].r) >> 1;
    if (l <= mid) ans += queryc(id << 1, l, r);
    if (r > mid) ans += queryc(id << 1 | 1, l, r);
    return ans;
}

int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
    build(1, 1, n);
    int opt, l, r, x;
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d", &opt, &l, &r);
        if(opt==1)
        {
            ll ans = querysum(1, l, r);
            printf("%lld
", ans);
        }
        if(opt==2)
        {
            ll ans = queryc(1, l, r);
            printf("%lld
", ans);
        }
        if(opt==3)
        {
            scanf("%d", &x);
            updatec(1, l, r, x);
        }
        if(opt==4)
        {
            scanf("%d", &x);
            updateadd(1, l, r, x);
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/EchoZQN/p/10846778.html