BZOJ 3745: [Coci2015]Norma(分治)

题意

给定一个正整数序列 (a_1, a_2, cdots, a_n) ,求

[sum_{i=1}^{n} sum_{j=i}^{n} (j - i + 1) min(a_i,a_{i+1},cdots,a_j) max(a_i,a_{i+1},cdots,a_j) pmod {10^9} ]

(n le 5 imes 10^5, a_i le 10^9)

题解

对于这种求一段区间内所有子区间答案和的东西,我们常常可以考虑分治解决。

通常思路是这样的:

假设我们计算 ([l, r]) 这段区间的答案和。令 (displaystyle mid = lfloor frac{l + r}{2} floor)

我们分治计算全在 ([l, mid]) 以及 ([mid + 1, r]) 区间的答案。

然后计算跨过 (mid) 的区间的答案。也就是左端点在 ([l, mid]) 并且右端点在 ([mid + 1, r]) 中的区间。

最后所有区间答案加起来就行了,正确性是显而易见的,因为一个区间要么全在左边,要么全在右边,要么跨过中点。

接下来难点就在跨过区间的答案计算上。

通常的话我们可以枚举一个端点,快速计算另外一个端点的贡献。

首先可以从 (mid)(l) 枚举这个区间的左端点 (x) ,令 ([x, mid]) 的最小值为 (a) ,最大值为 (b)

两个单调指针 (p,q)(mid) 向右走(最多到 (r) ),分别指向使得 ([mid/x, i]) 最小值/最大值 为 (a) / (b) 的最大位置 (i)

我们把区间右端点 (y) 分三种情况讨论。

  1. 对于 (y in [mid + 1, min {p, q}]) 这一部分 (min = a, max = b) 。那么答案就是

    [sum_{y=mid+1}^{ min {p, q}} (y - x + 1) imes a imes b ]

    这部分用高斯求和优化即可。

  2. 对于 (y in (min {p, q}, max {p, q}]) 这部分,会有两种情况,本质是一样的。

    我们讨论 (p < q) 这种,那么答案就是

    [sum_{y=min {p, q}+1}^{max {p, q}} (min_{k = mid}^{y} a_k) imes(y - x + 1) imes b ]

    我们可以考虑预处理 (displaystyle sum_{y=1}^{n} (min_{k = mid}^{y}a_k)~y) 以及 (displaystyle sum_{y=1}^{n} (min_{k = mid}^{y}a_k)) 就行了。那么每次计算就是前缀和相减,然后乘上系数就行了。

  3. 对于最后一部分 (y in (max {p, q}, r]) ,那么我们类似与上面那个式子处理

    [sum_{y=max {p, q}+1}^{r} (min_{k = mid}^{y} a_k) (max_{k = mid}^{y} a_k) imes(y - x + 1) ]

    那么我们只需要考虑预处理 (displaystyle sum_{y=1}^{n}(min_{k = mid}^{y} a_k)(max_{k = mid}^{y} a_k)~y) 以及 (displaystyle sum_{y=1}^{n}(min_{k = mid}^{y} a_k)(max_{k = mid}^{y} a_k)) ,类似于上面那个式子前缀和相减就行了。

然后这样就可以做完了,每次操作只需要扫一遍区间,所以复杂度就是 (O(n log n)) 的。

总结

考虑所有子区间答案,或者所有点对的答案,可以尝试考虑分治。

然后继续考虑枚举一个端点,另外一个端点用一些特殊结构快速计算就行了。

代码

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
	freopen ("3745.in", "r", stdin);
	freopen ("3745.out", "w", stdout);
#endif
}

const int N = 5e5 + 1e3, inf = 0x7f7f7f7f, Mod = 1e9;

int n, a[N]; int ans = 0;

int Sum(int l, int r) { return (1ll * (r - l + 1) * (r + l) / 2) % Mod; }

int SMM[N], SMMP[N], SMax[N], SMin[N], SMaxp[N], SMinp[N]; int Minv[N], Maxv[N];
// SumMaxMin SunMaxMinPos SumMax SumMin SumMaxPos SumMinpos MinVal MaxVal

inline int Add(int a, int b) {
	return (a += b) >= Mod ? a - Mod : a;
}

void Solve(int l, int r) {
	if (l == r) { ans = Add(ans, 1ll * a[l] * a[r] % Mod); return ; }
	int mid = (l + r) >> 1; Solve(l, mid); Solve(mid + 1, r);

	int p = mid, q = mid; 
	int minl = a[mid], maxl = a[mid];

	SMM[mid] = SMMP[mid] = SMax[mid] = SMin[mid] = SMaxp[mid] = SMinp[mid] = 0;
	For (i, mid + 1, r) {
		if (i == mid + 1) {
			Minv[i] = Maxv[i] = SMax[i] = SMin[i] = a[i];
			SMinp[i] = SMaxp[i] = 1ll * a[i] * i % Mod;
			SMM[i] = 1ll * a[i] * a[i] % Mod; SMMP[i] = 1ll * i * a[i] % Mod * a[i] % Mod;
			continue ;
		}
		Minv[i] = min(a[i], Minv[i - 1]); 
		Maxv[i] = max(a[i], Maxv[i - 1]);

		SMin[i] = Add(SMin[i - 1], Minv[i]);
		SMax[i] = Add(SMax[i - 1], Maxv[i]);
		SMinp[i] = Add(SMinp[i - 1], 1ll * Minv[i] * i % Mod);
		SMaxp[i] = Add(SMaxp[i - 1], 1ll * Maxv[i] * i % Mod);
		SMM[i] = Add(SMM[i - 1], 1ll * Minv[i] * Maxv[i] % Mod);
		SMMP[i] = Add(SMMP[i - 1], 1ll * i * Minv[i] % Mod * Maxv[i] % Mod);
	}

	Fordown (i, mid, l) {
		chkmin(minl, a[i]); chkmax(maxl, a[i]);

		while (p < r && Minv[p + 1] >= minl) ++ p;
		while (q < r && Maxv[q + 1] <= maxl) ++ q;

		int gapl = min(p, q), gapr = max(p, q);
		ans = Add(ans, 1ll * Sum(mid - i + 2, gapl - i + 1) * minl % Mod * maxl % Mod);
		if (p < q) ans = Add(ans, (((SMinp[q] - SMinp[p]) - 1ll * (SMin[q] - SMin[p]) * (i - 1)) % Mod + Mod) * maxl % Mod);
		if (p > q) ans = Add(ans, (((SMaxp[p] - SMaxp[q]) - 1ll * (SMax[p] - SMax[q]) * (i - 1)) % Mod + Mod) * minl % Mod);
		ans = Add(ans, (((SMMP[r] - SMMP[gapr]) - 1ll * (SMM[r] - SMM[gapr]) * (i - 1)) % Mod + Mod) % Mod);
	}
}

int main () {

	File();
	n = read(); For (i, 1, n) a[i] = read();
	Solve(1, n);
	printf ("%d
", ans);

	return 0;
}
原文地址:https://www.cnblogs.com/zjp-shadow/p/9483629.html