// 线段树模板(区间加法、区间赋值、区间求和)
// Segment tree template: range add, range assign, range sum.

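// 注意标记优先级:区间赋值 > 区间乘 > 区间加。
// 若扩展出区间乘标记,push_down 中判断乘法标记时一定要写 if (t[p].mul != 1),
// 而不是 if (t[p].mul),因为可能会乘 0。(本模板本身未实现区间乘。)
// Tag priority: assign > multiply > add. If you extend this with a range-multiply
// tag, push_down must test `if (t[p].mul != 1)`, not `if (t[p].mul)`, because
// multiplying by 0 is a valid update. (This template has no multiply tag.)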

#include <bits/stdc++.h>

using namespace std;

// Shorthand for the 64-bit integer type used throughout the file.
#define ll long long
// Heap-style child indices of node x: left child 2x, right child 2x + 1.
// The argument is parenthesized so that expressions containing lower-precedence
// operators (e.g. ls(a & b), ls(c ? a : b)) expand correctly.
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)

ll n;          // number of array elements
ll q;          // number of operations to process
ll a[100005];  // initial values, stored 1-indexed as a[1..n]

// One segment-tree node; it covers the index range [l, r].
// Nodes live in the global array t using heap indexing (root is t[1]).
struct node
{
	ll l;     // left end of the covered range
	ll r;     // right end of the covered range
	ll setv;  // pending "assign" tag; a sentinel value marks "no pending assignment"
	ll addv;  // pending "add" tag, applied on top of any assignment
	ll sumv;  // sum of the elements in [l, r]
} t[400005];

// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if end of input is reached before any digit.
ll read()
{
	ll x = 0ll, fl = 1ll;
	// int rather than char: getchar() returns EOF, which a char cannot hold
	// reliably, and without the EOF check the skip loop would spin forever
	// once the input is exhausted.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9'))
	{
		if (ch == '-') fl = -1ll;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9')
	{
		x = x * 10ll + (ch - '0');
		ch = getchar();
	}
	return x * fl;
}

// Recomputes node p's sum from the sums of its two children.
void push_up(ll p)
{
	const ll left_sum = t[ls(p)].sumv;
	const ll right_sum = t[rs(p)].sumv;
	t[p].sumv = left_sum + right_sum;
}

void push_down(ll p)
{
	if (t[p].setv != -1)
	{
		t[ls(p)].addv = t[rs(p)].addv = 0;
		t[ls(p)].setv = t[rs(p)].setv = t[p].setv;
		t[ls(p)].sumv = t[p].setv * (t[ls(p)].r - t[ls(p)].l + 1);
		t[rs(p)].sumv = t[p].setv * (t[rs(p)].r - t[rs(p)].l + 1);
		t[p].setv = -1;
	}
	if (t[p].addv)
	{
		t[ls(p)].addv += t[p].addv;
		t[rs(p)].addv += t[p].addv;
		t[ls(p)].sumv += t[p].addv * (t[ls(p)].r - t[ls(p)].l + 1);
		t[rs(p)].sumv += t[p].addv * (t[rs(p)].r - t[rs(p)].l + 1);
		t[p].addv = 0;
	}
	return;
}

void build(ll p, ll l0, ll r0)
{
	t[p].l = l0; t[p].r = r0; t[p].setv = -1;
	if (l0 == r0)
	{
		t[p].sumv = a[l0];
		return;
	}
	ll mid = (l0 + r0) / 2ll;
	build(ls(p), l0, mid);
	build(rs(p), mid + 1, r0);
	push_up(p);
	return;
}

// Applies a range update on [l0, r0], starting the descent at node p.
// tp == 1 assigns d to every element; any other tp adds d to every element.
void update(ll p, ll l0, ll r0, ll d, ll tp)
{
	node &cur = t[p];
	const ll len = cur.r - cur.l + 1;
	if (cur.l >= l0 && cur.r <= r0)
	{
		// Node fully inside the target range: record the tag here and stop.
		if (tp != 1) // add
		{
			cur.addv += d;
			cur.sumv += d * len;
		}
		else // set: overrides any pending add
		{
			cur.setv = d;
			cur.addv = 0;
			cur.sumv = d * len;
		}
		return;
	}
	push_down(p);
	const ll mid = (cur.l + cur.r) / 2ll;
	if (mid >= l0) update(ls(p), l0, r0, d, tp);
	if (mid < r0) update(rs(p), l0, r0, d, tp);
	push_up(p);
}

// Returns the sum of the elements in [l0, r0] that lie within node p's range.
ll query(ll p, ll l0, ll r0)
{
	const ll lo = t[p].l, hi = t[p].r;
	if (l0 > lo || hi > r0)
	{
		// Partial overlap: materialize pending tags, then combine the children.
		push_down(p);
		const ll mid = (lo + hi) / 2ll;
		ll total = 0;
		if (mid >= l0) total += query(ls(p), l0, r0);
		if (mid < r0) total += query(rs(p), l0, r0);
		return total;
	}
	return t[p].sumv;
}

int main()
{
	n = read(); q = read();
	for (ll i = 1; i <= n; i ++ )
		a[i] = read();
	build(1, 1, n);
	while (q -- )
	{
		ll opt = read();
		if (opt == 1) // query
		{
			ll l0 = read(), r0 = read();
			printf("%lld
", query(1, l0, r0));
		}
		else if (opt == 2) // set
		{
			ll l0 = read(), r0 = read(), x = read();
			update(1, l0, r0, x, 1);
		}
		else // add
		{
			ll l0 = read(), r0 = read(), x = read();
			update(1, l0, r0, x, 2);
		}
	}
	return 0;
}
// 原文地址 (original article): https://www.cnblogs.com/andysj/p/13935785.html