维护序列

题目描述

思路

借鉴

代码

#include <cstdio>
// Child indices of node k in the 1-based heap layout: left = 2k, right = 2k+1.
// Parenthesized so the macros stay correct inside larger expressions
// (without parentheses, `lc + 1` would expand to `k << (1 + 1)`).
#define lc (k << 1)
#define rc (k << 1 | 1)

using namespace std;

// Sequence length, modulus and number of operations; all three are set by main().
int n;
int m;
int w;
// Segment-tree storage, indexed from node 1 with children k<<1 and k<<1|1.
long long sum[100005 << 2];  // sum of the node's segment modulo m
long long add[100005 << 2];  // pending "add" tag, applied after the multiply tag
long long mul[100005 << 2];  // pending "multiply" tag
// Input sequence, 1-based (at[1..n]).
long long at[100005];
// Reads the next decimal integer from stdin, skipping any non-digit characters
// before it. A '-' immediately preceding the digits makes the result negative.
// If end of input is reached before any digit, returns 0 instead of spinning.
inline int read() {
	int s = 0;
	int sign = 1;
	int ch = getchar();  // int, not char, so EOF can be told apart from data
	while (ch != EOF && (ch < '0' || ch > '9')) {
		// Only a '-' directly in front of the number counts as its sign.
		sign = (ch == '-') ? -1 : 1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return sign * s;
}
// Builds node k over the segment [l, r] from at[l..r]. Every node starts with
// the identity tags (multiply by 1, add 0); leaves store at[l] reduced mod m
// and internal nodes store the sum of their children mod m.
void build(int k, int l, int r) {
	mul[k] = 1;
	add[k] = 0;
	if (l != r) {
		const int mid = (l + r) >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		sum[k] = (sum[lc] + sum[rc]) % m;
	} else {
		sum[k] = at[l] % m;
	}
}

// Pushes node k's pending tags (k covers [l, r]) down to its two children.
// Tags mean "multiply by mul[k], then add add[k]", so the multiply tag is
// pushed first and also scales each child's pending add tag.
void pushdown(int k, int l, int r) {
	const int mid = (l + r) >> 1;
	const int child[2] = {lc, rc};
	const int len[2] = {mid - l + 1, r - mid};

	if (mul[k] != 1) {
		const long long f = mul[k];
		for (int i = 0; i < 2; ++i) {
			const int c = child[i];
			mul[c] = mul[c] * f % m;
			add[c] = add[c] * f % m;
			sum[c] = sum[c] * f % m;
		}
		mul[k] = 1;
	}

	if (add[k] != 0) {
		const long long d = add[k];
		for (int i = 0; i < 2; ++i) {
			const int c = child[i];
			add[c] = (add[c] + d) % m;
			// The child's sum grows by d for each element of its segment.
			sum[c] = (sum[c] + len[i] * d) % m;
		}
		add[k] = 0;
	}
}

// Multiplies every element of [x, y] by z inside node k, which covers [l, r].
// main() passes z already reduced modulo m.
void upmul(int k, int l, int r, int x, int y, int z) {
	if (!(x <= l && r <= y)) {
		// Partial overlap: settle pending tags, recurse, then recombine.
		pushdown(k, l, r);
		const int mid = (l + r) >> 1;
		if (x <= mid) upmul(lc, l, mid, x, y, z);
		if (mid < y) upmul(rc, mid + 1, r, x, y, z);
		sum[k] = (sum[lc] + sum[rc]) % m;
		return;
	}
	// Node lies fully inside [x, y]: scale its sum and both pending tags.
	sum[k] = sum[k] * z % m;
	mul[k] = mul[k] * z % m;
	add[k] = add[k] * z % m;
}

// Adds z to every element of [x, y] inside node k, which covers [l, r].
// main() passes z already reduced modulo m.
void upadd(int k, int l, int r, int x, int y, int z) {
	if (x <= l && r <= y) {
		// Widen to 64 bits before multiplying: a segment length of up to ~1e5
		// times z (up to m - 1) can overflow int when m is large.
		sum[k] = (sum[k] + 1LL * (r - l + 1) * z) % m;
		add[k] = (add[k] + z) % m;
		return;
	}
	pushdown(k, l, r);
	const int mid = (l + r) >> 1;
	if (x <= mid) upadd(lc, l, mid, x, y, z);
	if (y > mid) upadd(rc, mid + 1, r, x, y, z);
	sum[k] = (sum[lc] + sum[rc]) % m;
}

// Returns the sum of elements in [x, y] modulo m, searching node k over [l, r].
long long query(int k, int l, int r, int x, int y) {
	if (x <= l && r <= y) return sum[k] % m;
	pushdown(k, l, r);
	const int mid = (l + r) >> 1;
	long long total = 0;
	if (x <= mid) total += query(lc, l, mid, x, y) % m;
	if (y > mid) total += query(rc, mid + 1, r, x, y);
	return total % m;
}
int main() {
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) at[i] = read();
	build(1, 1, n);
	w = read();
	for (int i = 1, a, b, c, d, e; i <= w; ++i) {
		a = read();
		if (a == 1 || a == 2) {
			b = read(), c = read(), d = read();
			if (a == 1) upmul(1, 1, n, b, c, d % m);
			else upadd(1, 1, n, b, c, d % m);
		} else {
			b = read(), c = read();
			printf("%lld
", query(1, 1, n, b, c));
		}
	}
	return 0;
}

下面这份代码用 queue 保存懒标记（tag）：每次下传都要把队列里的标记逐个转移给两个子节点，因此容易超时。

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <list>
#include <queue>
#include <utility>
using namespace std;

// Sequence length, modulus and number of operations; set by main().
int n, m, w;
// arr: segment sums modulo m (node 1 is the root); at: input sequence, 1-based.
int arr[100005 << 2], at[100005];
// Pending tags of every node in arrival order (first = 1 multiply, else add;
// second = operand). The queues use std::list instead of the default
// std::deque: libstdc++'s deque allocates a buffer even when empty, and with
// 400020 queues that costs over 200 MB before any tag is pushed. An empty
// std::list allocates nothing.
std::queue<std::pair<int, int>, std::list<std::pair<int, int> > > q[100005 << 2];
// Reads the next decimal integer from stdin, skipping any non-digit characters
// before it. A '-' immediately preceding the digits makes the result negative.
// If end of input is reached before any digit, returns 0 instead of spinning.
inline int read() {
    int s = 0;
    int sign = 1;
    int ch = getchar();  // int, not char, so EOF can be told apart from data
    while (ch != EOF && (ch < '0' || ch > '9')) {
        // Only a '-' directly in front of the number counts as its sign.
        sign = (ch == '-') ? -1 : 1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return sign * s;
}
// Builds node k over [l, r] from at[l..r]; every stored sum is reduced mod m.
inline void build(int k, int l, int r) {
    if (l == r) {
        // Reduce the leaf too: query() returns arr[k] unchanged for a fully
        // covered node, so an unreduced leaf would be printed as a value >= m.
        arr[k] = at[l] % m;
        return;
    }
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
    arr[k] = (0LL + arr[k << 1] + arr[k << 1 | 1]) % m;
}

// Applies tag p to node k, which covers [l, r], and queues it so that a later
// pushdown() can replay it on the children.
// p.first == 1: multiply every element by p.second; otherwise add p.second.
inline void trans(int k, int l, int r, std::pair<int, int> p) {
    q[k].push(p);
    if (p.first == 1)
        arr[k] = 1LL * arr[k] * p.second % m;
    else
        // Widen before multiplying: the segment length (up to ~1e5) times an
        // unreduced operand (up to ~1e9) overflows int.
        arr[k] = (arr[k] + 1LL * (r - l + 1) * p.second) % m;
}
// Replays every queued tag of node k (covering [l, r]) on both children,
// oldest first, and leaves k's queue empty.
inline void pushdown(int k, int l, int r) {
    const int mid = (l + r) >> 1;
    for (; !q[k].empty(); q[k].pop()) {
        trans(k << 1, l, mid, q[k].front());
        trans(k << 1 | 1, mid + 1, r, q[k].front());
    }
}
// Applies operation op (1 = multiply, otherwise add) with operand val to every
// element of [x, y], starting from node k, which covers [l, r].
inline void change(int k, int l, int r, int x, int y, int op, int val) {
    const bool covered = (x <= l && r <= y);
    if (covered) {
        trans(k, l, r, std::make_pair(op, val));
        return;
    }
    pushdown(k, l, r);
    const int mid = (l + r) >> 1;
    const int lson = k << 1;
    const int rson = lson | 1;
    if (x <= mid) change(lson, l, mid, x, y, op, val);
    if (mid < y) change(rson, mid + 1, r, x, y, op, val);
    arr[k] = (0LL + arr[lson] + arr[rson]) % m;
}
// Returns the sum of [x, y] as stored in the tree (node k covers [l, r]); a
// sum combined from both children is reduced modulo m.
inline int query(int k, int l, int r, int x, int y) {
    if (x <= l && r <= y) return arr[k];
    pushdown(k, l, r);
    const int mid = (l + r) >> 1;
    long long acc = 0;
    if (x <= mid) acc = query(k << 1, l, mid, x, y);
    if (mid < y) acc = (acc + query(k << 1 | 1, mid + 1, r, x, y)) % m;
    return static_cast<int>(acc);
}
int main() {
    n = read(), m = read();
    for (int i = 1; i <= n; ++i) at[i] = read();
    build(1, 1, n);
    w = read();
    for (int i = 1, a, b, c, d, e; i <= w; ++i) {
        a = read();
        if (a == 1 || a == 2) {
            b = read(), c = read(), d = read();
            if (a == 1)
                change(1, 1, n, b, c, 1, d);
            else
                change(1, 1, n, b, c, 2, d);
        } else {
            b = read(), c = read();
            printf("%d
", query(1, 1, n, b, c));
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/liuzz-20180701/p/11497893.html