luogu3373 【模板】线段树2

题目大意:

已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和

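本线段树的标记是个二元组 $(mul, add)$,其含义是将线段中的每一个点先乘以 $mul$ 再加上 $add$。设区间长度为 $x$,原来的区间和为 $sum$,则打上该标记后区间和为 $sum \cdot mul + add \cdot x$。如果再叠加一个标记 $(mul', add')$,区间和将变为 $(sum \cdot mul + add \cdot x) \cdot mul' + add' \cdot x = sum \cdot mul \cdot mul' + (add \cdot mul' + add') \cdot x$。所以叠加时令 $mul \leftarrow mul \cdot mul'$,$add \leftarrow add \cdot mul' + add'$ 即可(新的 $add$ 与区间长度无关,因此可以直接存在标记里)。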

注意:

  • 尽管数据关于P取模了,但是因为有数据相乘的操作,所以程序中所有的值类型都要是long long
  • 编写ModPlus、ModMult时,以ModMult为例:只要保证参与运算的x、y都已经在[0, P)之内,就不要写成((x%P)*(y%P))%P,而应直接写成(x*y)%P(此时乘积小于P²,只要P²不超出long long范围就不会溢出)。多余的取模运算会增大常数,可能导致超时(即"被卡常数")。
#include <cstdio>
#include <cstring>
#include <cassert>
using namespace std;

// Capacity limits. Element positions are 1-based, so the largest usable
// index is MAX_RANGE - 1; the node arrays use the usual 4x bound for a
// recursive segment tree.
const int MAX_RANGE = 100010;
const int MAX_NODE = 4 * MAX_RANGE;

// Iterates i = 1, 2, ..., n (inclusive).
#define LOOP(i, n) for (int i = 1; i <= n; ++i)

// Modulus and number of elements; both are read from input in main().
long long P, TotRange;

// Initial sequence, stored 1-based in OrgData[1..TotRange].
long long OrgData[MAX_RANGE];

// Lazy-propagation segment tree over positions [1, TotRange] supporting
// range multiply, range add and range sum, all taken modulo the global P.
//
// Every node carries a lazy tag (mul, add) meaning "each element in this
// node's range becomes element * mul + add". Sum[cur] always holds the
// node's sum with its own tag already applied; the tag only records what
// still has to be pushed down to the children.
//
// Invariant: every stored value (sums, tag fields, update values) is kept in
// [0, P), so the product of two stored values stays below P * P.
// NOTE(review): assumes P * P fits in long long (P <= ~3e9) — confirm
// against the problem limits.
struct RangeTree
{
private:
	// Reduces any long long (possibly negative) into [0, P).
	static long long Norm(long long x)
	{
		x %= P;
		return x < 0 ? x + P : x;
	}

	// (x + y) mod P for x, y already in [0, P).
	static long long ModPlus(long long x, long long y) { return (x + y) % P; }

	// (x * y) mod P; operands are already reduced, so no extra '%' is needed
	// on them (this is the cheaper form recommended by the notes above).
	static long long ModMult(long long x, long long y) { return x * y % P; }

	struct Tag
	{
		long long mul, add;

		Tag() : mul(1), add(0) {}
		Tag(long long m, long long a) : mul(m), add(a) {}

		// Composes tag x on top of this one:
		// (v*mul + add)*x.mul + x.add = v*(mul*x.mul) + (add*x.mul + x.add).
		void Refresh(const Tag &x)
		{
			mul = ModMult(mul, x.mul);
			add = ModPlus(ModMult(add, x.mul), x.add);
		}

		void Clear() { mul = 1; add = 0; }

		bool IsIdentity() const { return mul == 1 && add == 0; }

		// New sum of range [l, r] whose old sum was `sum`, after this tag:
		// sum*mul + add*(r - l + 1). Works in long long to avoid truncation.
		long long GetSum(long long sum, int l, int r) const
		{
			return ModPlus(ModMult(sum, mul), ModMult(add, r - l + 1));
		}
	};

	Tag _tags[MAX_NODE];
	long long Sum[MAX_NODE];

	// Applies cur's pending tag to both children (their sums and tags), then
	// clears it. Only reached for internal nodes: a leaf that overlaps the
	// query/update range is always fully covered, so it returns earlier.
	void PushDown(int cur, int l, int r)
	{
		Tag &t = _tags[cur];
		if (t.IsIdentity())
			return;
		int mid = (l + r) / 2;
		Sum[cur * 2] = t.GetSum(Sum[cur * 2], l, mid);
		Sum[cur * 2 + 1] = t.GetSum(Sum[cur * 2 + 1], mid + 1, r);
		_tags[cur * 2].Refresh(t);
		_tags[cur * 2 + 1].Refresh(t);
		t.Clear();
	}

	// Recomputes cur's sum from its two children.
	void PullUp(int cur)
	{
		Sum[cur] = ModPlus(Sum[cur * 2], Sum[cur * 2 + 1]);
	}

	// Node cur covers [sl, sr]; applies the update to [al, ar].
	// op == 1: multiply by value; op == 2: add value; other ops change nothing.
	// value must already be reduced into [0, P).
	void Update(int cur, int sl, int sr, int al, int ar, int op, long long value)
	{
		assert(al <= ar && sl <= sr && al <= sr && ar >= sl);
		if (al <= sl && sr <= ar)
		{
			if (op == 1)
			{
				Sum[cur] = ModMult(Sum[cur], value);
				_tags[cur].Refresh(Tag(value, 0));
			}
			else if (op == 2)
			{
				// Each of the (sr - sl + 1) elements grows by value; multiply
				// modulo P in long long (the old int product could overflow).
				Sum[cur] = ModPlus(Sum[cur], ModMult(value, sr - sl + 1));
				_tags[cur].Refresh(Tag(1, value));
			}
			return;
		}
		PushDown(cur, sl, sr);
		int mid = (sl + sr) / 2;
		if (al <= mid)
			Update(cur * 2, sl, mid, al, ar, op, value);
		if (ar > mid)
			Update(cur * 2 + 1, mid + 1, sr, al, ar, op, value);
		PullUp(cur);
	}

	// Node cur covers [sl, sr]; returns the sum over [al, ar] mod P.
	long long Query(int cur, int sl, int sr, int al, int ar)
	{
		assert(al <= ar && sl <= sr && al <= sr && ar >= sl);
		if (al <= sl && sr <= ar)
			return Sum[cur];
		// PushDown leaves Sum[cur] unchanged (its tag was already applied),
		// so no PullUp is needed afterwards.
		PushDown(cur, sl, sr);
		int mid = (sl + sr) / 2;
		long long ans = 0;
		if (al <= mid)
			ans = ModPlus(ans, Query(cur * 2, sl, mid, al, ar));
		if (ar > mid)
			ans = ModPlus(ans, Query(cur * 2 + 1, mid + 1, sr, al, ar));
		return ans;
	}

	// Builds the subtree rooted at cur over a[l..r], clearing all tags.
	void SetEachNode(long long *a, int cur, int l, int r)
	{
		_tags[cur].Clear();
		if (l == r)
		{
			// Reduce raw input so every stored sum is in [0, P).
			Sum[cur] = Norm(a[l]);
			return;
		}
		int mid = (l + r) / 2;
		SetEachNode(a, cur * 2, l, mid);
		SetEachNode(a, cur * 2 + 1, mid + 1, r);
		PullUp(cur);
	}

public:
	RangeTree() {}

	// Builds the tree from a[1..TotRange]; values are reduced modulo P.
	void SetEachNode(long long *a)
	{
		SetEachNode(a, 1, 1, TotRange);
	}

	// op == 1: a[l..r] *= value; op == 2: a[l..r] += value (both mod P).
	// Requires 1 <= l <= r <= TotRange. Negative values are normalized.
	void Update(int l, int r, int op, int value)
	{
		Update(1, 1, TotRange, l, r, op, Norm(value));
	}

	// Returns (a[l] + ... + a[r]) mod P. Requires 1 <= l <= r <= TotRange.
	long long Query(int l, int r)
	{
		return Query(1, 1, TotRange, l, r);
	}
}g;

int main()
{
	int opCnt, op, l, r, val;
	scanf("%lld%d%lld", &TotRange, &opCnt, &P);
	LOOP(i, TotRange)
		scanf("%lld", OrgData + i);
	g.SetEachNode(OrgData);
	while (opCnt--)
	{
		scanf("%d", &op);
		switch (op)
		{
		case 1://Mult
			scanf("%d%d%d", &l, &r, &val);
			g.Update(l, r, 1, val);
			break;
		case 2://Plus
			scanf("%d%d%d", &l, &r, &val);
			g.Update(l, r, 2, val);
			break;
		case 3://Query
			scanf("%d%d", &l, &r);
			printf("%lld
", g.Query(l, r));
			break;
		}
	}
	return 0;
}

  

原文地址:https://www.cnblogs.com/headboy2002/p/8646870.html