// 注意优先顺序:区间赋值>区间乘>区间加
// Note the lazy-tag priority order: range assignment > range multiplication > range addition.
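// 一定记得(push down)时乘法标记判断是(if (t[p].mul != 1))而不是(if (t[p].mul))(因为可能会乘(0))
// When pushing down, the multiplication-tag check must be `if (t[p].mul != 1)`, not `if (t[p].mul)`, because a range may legitimately be multiplied by 0.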
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit integer type used throughout this file.
#define ll long long
// Child indices in the implicitly stored segment tree: node p has children
// 2p and 2p+1 (the root is node 1). The argument is parenthesized so that
// expressions such as ls(a & b) expand as intended.
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
// n: number of elements, q: number of operations,
// a[1..n]: initial values (index 0 unused; sized for n up to about 1e5).
ll n, q, a[100005];
// One segment-tree node covering the inclusive index range [l, r].
struct node
{
    ll l, r;  // inclusive bounds of the covered index range
    ll setv;  // pending range-assignment tag (build/push_down define the "none" sentinel)
    ll addv;  // pending range-addition tag, pushed down after setv
    ll sumv;  // sum over [l, r], already including this node's own pending tags
} t[400005]; // node storage; the root is t[1]
// Reads one decimal integer from stdin.
// Any non-digit characters before the first digit are skipped; a '-' among
// them makes the result negative. Reading stops at (and consumes) the first
// non-digit after the number.
// Returns 0 if end of input is reached before any digit; the previous
// version looped forever in that case.
long long read()
{
    long long x = 0ll, fl = 1ll;
    int ch = getchar(); // int, not char, so EOF can be told apart from a real byte
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') fl = -1ll;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        // Parenthesized so the intermediate value never exceeds the final one.
        x = x * 10ll + (ch - '0');
        ch = getchar();
    }
    return x * fl;
}
// Recomputes node p's sum from the sums of its two children.
void push_up(ll p)
{
    const node &lc = t[ls(p)];
    const node &rc = t[rs(p)];
    t[p].sumv = lc.sumv + rc.sumv;
}
void push_down(ll p)
{
if (t[p].setv != -1)
{
t[ls(p)].addv = t[rs(p)].addv = 0;
t[ls(p)].setv = t[rs(p)].setv = t[p].setv;
t[ls(p)].sumv = t[p].setv * (t[ls(p)].r - t[ls(p)].l + 1);
t[rs(p)].sumv = t[p].setv * (t[rs(p)].r - t[rs(p)].l + 1);
t[p].setv = -1;
}
if (t[p].addv)
{
t[ls(p)].addv += t[p].addv;
t[rs(p)].addv += t[p].addv;
t[ls(p)].sumv += t[p].addv * (t[ls(p)].r - t[ls(p)].l + 1);
t[rs(p)].sumv += t[p].addv * (t[rs(p)].r - t[rs(p)].l + 1);
t[p].addv = 0;
}
return;
}
void build(ll p, ll l0, ll r0)
{
t[p].l = l0; t[p].r = r0; t[p].setv = -1;
if (l0 == r0)
{
t[p].sumv = a[l0];
return;
}
ll mid = (l0 + r0) / 2ll;
build(ls(p), l0, mid);
build(rs(p), mid + 1, r0);
push_up(p);
return;
}
// Applies a range operation over [l0, r0] in the subtree rooted at node p.
// tp == 1: assign d to every element (replaces any pending addition).
// otherwise: add d to every element.
void update(ll p, ll l0, ll r0, ll d, ll tp)
{
    const bool covered = l0 <= t[p].l && t[p].r <= r0;
    if (covered)
    {
        // The whole node lies inside the range: tag it here and stop.
        const ll len = t[p].r - t[p].l + 1;
        if (tp != 1) // add
        {
            t[p].addv += d;
            t[p].sumv += d * len;
        }
        else // set
        {
            t[p].setv = d;
            t[p].addv = 0;
            t[p].sumv = d * len;
        }
        return;
    }
    // Partial overlap: flush pending tags, recurse into overlapping children,
    // then refresh this node's sum.
    push_down(p);
    const ll mid = (t[p].l + t[p].r) / 2ll;
    if (l0 <= mid) update(ls(p), l0, r0, d, tp);
    if (mid < r0) update(rs(p), l0, r0, d, tp);
    push_up(p);
}
// Returns the sum of the elements in [l0, r0] that lie within node p's range.
ll query(ll p, ll l0, ll r0)
{
    // Node entirely inside the query range: its stored sum is already exact.
    if (t[p].l >= l0 && r0 >= t[p].r) return t[p].sumv;
    push_down(p);
    const ll mid = (t[p].l + t[p].r) / 2ll;
    ll left_part = 0, right_part = 0;
    if (l0 <= mid) left_part = query(ls(p), l0, r0);
    if (mid < r0) right_part = query(rs(p), l0, r0);
    return left_part + right_part;
}
int main()
{
n = read(); q = read();
for (ll i = 1; i <= n; i ++ )
a[i] = read();
build(1, 1, n);
while (q -- )
{
ll opt = read();
if (opt == 1) // query
{
ll l0 = read(), r0 = read();
printf("%lld
", query(1, l0, r0));
}
else if (opt == 2) // set
{
ll l0 = read(), r0 = read(), x = read();
update(1, l0, r0, x, 1);
}
else // add
{
ll l0 = read(), r0 = read(), x = read();
update(1, l0, r0, x, 2);
}
}
return 0;
}