Codeforces 719E (线段树教做人系列) 线段树维护矩阵

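题面简洁明了，一看就懂。题意：给定长度为 $n$ 的序列 $a$，有 $m$ 次操作：`1 l r x` 表示把 $a_l, \dots, a_r$ 都加上 $x$；`2 l r` 表示询问 $\left(f(a_l) + \dots + f(a_r)\right) \bmod (10^9+7)$，其中 $f$ 是斐波那契数列（$f(1) = f(2) = 1$）。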

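做了这个题之后，才知道怎么用线段树维护递推式。递推式的递推过程可以看作两个矩阵相乘：设矩阵 $A$ 是初始值矩阵，矩阵 $B$ 是变换矩阵，则求第 $n$ 项相当于计算 $A B^{n-1}$，即把矩阵 $B$ 乘了 $n-1$ 次。对斐波那契数列，取 $A = \begin{pmatrix} f_0 & f_1 \end{pmatrix} = \begin{pmatrix} 0 & 1 \end{pmatrix}$，$B = \begin{pmatrix} 0 & 1 \\ 1 & 1 \end{pmatrix}$，则 $A B^{n-1} = \begin{pmatrix} f_{n-1} & f_n \end{pmatrix}$，第二个分量就是所求的 $f_n$。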

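那么线段树的每个结点维护两样东西：区间内各位置矩阵 $A B^{a_i - 1}$ 之和 sum，以及尚未下放的懒标记矩阵 flag（即该结点还要再乘上的 $B$ 的幂，相当于记录矩阵 $B$ 还要再乘多少次）。区间加 $x$ 时，先用快速幂求出 $B^x$；由于矩阵乘法对加法满足分配律，$\sum_i M_i B^x = \left(\sum_i M_i\right) B^x$，所以把结点的 sum 和 flag 直接右乘 $B^x$ 即可。懒标记下放时，把父结点的 flag 右乘到两个子结点的 sum 和 flag 上，再把父结点的 flag 重置为单位矩阵。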

#include <bits/stdc++.h>
// 64-bit integer type used for all modular arithmetic.
typedef long long LL;
// Left / right child of segment-tree node x (the root is node 1).
// Written as functions rather than macros so the argument is evaluated as a
// single value: the macro form `(x << 1)` turned ls(a | b) into (a | b << 1).
static inline int ls(int x) { return x << 1; }
static inline int rs(int x) { return (x << 1) | 1; }
using namespace std;
// Modulus applied to every matrix entry and every query answer.
const LL mod = 1000000007;
// Capacity of the 1-indexed input array; the tree uses 4 * maxn nodes.
const int maxn = 100010;
struct Matrix {
	static const int len = 2;
	LL x[len][len];
	
	void init() {
		memset(x, 0, sizeof(x));
		for (int i = 0; i < len; i++)
			x[i][i] = 1;
	}
	
	void zero() {
		memset(x, 0, sizeof(x));
	}
	
	Matrix operator * (const Matrix& m) const {
		Matrix ans;
		ans.zero();
		for (int i = 0; i < len; i++)
			for (int j = 0; j < len; j++)
				for (int k = 0; k < len; k++)
					ans.x[i][j] = (ans.x[i][j] + x[i][k] * m.x[k][j]) % mod;
		return ans;
	}
	
	Matrix operator + (const Matrix& m) const {
		Matrix ans;
		ans.zero();
		for (int i = 0; i < len; i++)
			for (int j = 0; j < len; j++)
				ans.x[i][j] = (x[i][j] + m.x[i][j]) % mod;
		return ans;
	}
	
	Matrix operator ^ (int b) const {
		Matrix ans, a;
		ans.init();
		memcpy(a.x, x, sizeof(x));
		for (; b; b >>= 1) {
			if(b & 1) ans = ans * a;
			a = a * a;
		}
		return ans;
	}
};

// Fibonacci step matrix B = [[0, 1], [1, 1]]; filled in by main.
Matrix mul;
// Holds B^x while a range update is being applied.
Matrix tmp;
// Row vector [F(0), F(1)] = [0, 1], stored as row 0 of a 2x2 matrix; set up in main.
Matrix trans;
// Input sequence, read into positions 1..n.
int a[maxn];

// A single segment-tree node.
struct SegementTree {
	// Non-zero when `flag` still has to be pushed down to the children.
	int lz;
	// Sum of the per-position matrices over this node's range.
	Matrix sum;
	// Pending transform that is right-multiplied into both children on pushdown.
	Matrix flag;
};

// Node storage: the root is node 1, and the children of o are ls(o) and rs(o).
SegementTree tr[maxn * 4];

void maintain(int o) {
	tr[o].sum = tr[ls(o)].sum + tr[rs(o)].sum;
}

void pushdown(int o) {
	if(tr[o].lz) {
		tr[ls(o)].sum = tr[ls(o)].sum * tr[o].flag;
		tr[rs(o)].sum = tr[rs(o)].sum * tr[o].flag;
		tr[ls(o)].flag = tr[ls(o)].flag * tr[o].flag;
		tr[rs(o)].flag = tr[rs(o)].flag * tr[o].flag;
		tr[o].lz = 0;
		tr[ls(o)].lz = 1;
		tr[rs(o)].lz = 1;
		tr[o].flag.init();
	}
}

void build(int o, int l, int r) {
	tr[o].sum.zero();
	tr[o].lz = 0;
	tr[o].flag.init();
	if(l == r) {
		tr[o].sum = trans * ( mul ^ (a[l] - 1));
		return;
	}
	int mid = (l + r) >> 1;
	build(ls(o), l, mid);
	build(rs(o), mid + 1, r);
	maintain(o);
}

void update(int o, int l, int r, int ql, int qr, Matrix now) {
	if(l >= ql && r <= qr) {
		tr[o].sum = tr[o].sum * now;
		tr[o].flag = tr[o].flag * now;
		tr[o].lz = 1;
		return;
	}
	pushdown(o);
	int mid = (l + r) >> 1;
	if(ql <= mid) update(ls(o), l, mid, ql, qr, now);
	if(qr > mid) update(rs(o), mid + 1, r, ql, qr, now);
	maintain(o);
}

// Returns the sum of entry [0][1] of the per-position matrices over
// [ql, qr] (the sum of F(a[i]) in this program) modulo mod, searching the
// subtree o that covers [l, r].
LL query(int o, int l, int r, int ql, int qr) {
	// Node range disjoint from the query (also the case for ql > qr):
	// contributes nothing. This also prevents endless descent below a leaf
	// that is neither covered nor disjoint when ql > qr.
	if (qr < l || r < ql) return 0;
	if (ql <= l && r <= qr) {
		return tr[o].sum.x[0][1];
	}
	pushdown(o);
	const int mid = (l + r) >> 1;
	// Each part is already reduced below mod, so the sum fits easily in LL.
	const LL left = query(ls(o), l, mid, ql, qr);
	const LL right = query(rs(o), mid + 1, r, ql, qr);
	return (left + right) % mod;
}

int main() {
	int n, m, op, l, r;
	LL x;
	trans.zero();
	trans.x[0][1] = 1;
	mul.x[0][1] = mul.x[1][0] = mul.x[1][1] = 1;
	mul.x[0][0] = 0;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
	}
	build(1, 1, n);
	for (int i = 1; i <= m; i++) {
		scanf("%d%d%d", &op, &l, &r);
		if(op == 1) {
			scanf("%lld", &x);
			tmp = (mul ^ x);
			update(1, 1, n, l, r, tmp);
		} else {
			printf("%lld
", query(1, 1, n, l, r));
		}
	}
}

  

原文地址:https://www.cnblogs.com/pkgunboat/p/10608454.html