JZOJ 4418. 【HNOI2016模拟4.1】Prime的把妹计划(单调栈+线段树)

JZOJ 4418. 【HNOI2016模拟4.1】Prime的把妹计划

题目大意

  • 给定序列 A 1.. N A_{1..N} A1..N,分别求出 Q Q Q组询问区间 [ L , R ] [L,R] [L,R]中最长的连续子序列 A l . . r A_{l..r} Al..r的长度,须满足该子序列中任意元素都在 [ A l , A r ] ( A l ≤ A r ) [A_l,A_r](A_l≤A_r) [Al,Ar](AlAr) [ A r , A l ] ( A r ≤ A l ) [A_r,A_l](A_r≤A_l) [Ar,Al](ArAl)中,也就是区间最大/最小值分别在左右端点。
  • N , Q ≤ 5 ∗ 1 0 5 N,Q≤5*10^5 N,Q5105.

题解

  • 这道题的题意很容易让人误解,做的时候理解错了两次。。。
  • 注意需要保证连续。
  • 不妨先考虑 A l ≤ A r A_l≤A_r AlAr的情况,另外一种情况只需要把序列和询问都倒过来再求一遍。
  • 先设 t o i to_i toi表示 A i A_i Ai左边第一个比 A i A_i Ai大的数的位置,可以用单调队列来求。
  • 接下来维护一个单调栈,满足单调不下降,
  • 对于当前的 A i A_i Ai,分两种情况:
  • 一、 A i A_i Ai大于等于栈顶元素
  • 此时以 t o i to_i toi右边且在栈中的元素为左端点的区间最优右端点为 i i i
  • 首先要满足在 t o i to_i toi右边,因为 t o i to_i toi就会成为最大值,占据右端点的位置,不合题意;
  • 其次要满足还在栈中,若不在栈中,则一定是因为它右边出现了比它小的元素才把它弹出的,那么它就不会是最小值了,也不合题意。
  • 用一个下标与该单调栈一一对应的线段树维护这个过程,区间修改一段的右端点为 i i i
  • 然后把 A i A_i Ai压入栈顶。
  • 二、 A i A_i Ai大于等于栈顶元素
  • 不断弹栈,每弹出一个元素就把它加入用另一个与原序列下标一一对应的线段树,然后把第一棵线段树对应位置清空(不然会WA),
  • 最后还是把 A i A_i Ai压入栈顶。
  • 考虑答案怎么求?
  • 把询问按右端点排序,当 A R A_{R} AR被加入栈中之后,在两棵线段树上分别查询 [ L , R ] [L,R] [L,R]对应的区间最大值,更新答案。
  • 注意 L L L在第一棵线段树中不一定存在对应的下标,取它右边最小的那个即可。
  • 时间复杂度 O ( ( Q + N ) log ⁡ 2 N ) O((Q+N)log_2N) O((Q+N)log2N).

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 50010
int n, Q, a[N];
int to[N], qu[N], ans[N];
int f[N * 4], g[N * 4], bz[N * 4];
struct node {
	int l, r, id;
}q[N];
int cmp(node x, node y) {
	return x.r < y.r;
}
void change(int v, int l, int r, int x, int y, int c) {
	if(l == x && r == y) {
		f[v] = c - qu[l] + 1;
		bz[v] = c;
	}
	else {
		int mid = (l + r) / 2;
		if(bz[v]) {
			bz[v * 2] = bz[v * 2 + 1] = bz[v];
			f[v * 2] = bz[v] - qu[l] + 1;
			f[v * 2 + 1] = bz[v] - qu[mid + 1] + 1;
			bz[v] = 0;
		}
		if(y <= mid) change(v * 2, l, mid, x, y, c);
		else if(x > mid) change(v * 2 + 1, mid + 1, r, x, y, c);
		else {
			change(v * 2, l, mid, x, mid, c);
			change(v * 2 + 1, mid + 1, r, mid + 1, y, c);
		}
		f[v] = max(f[v * 2], f[v * 2 + 1]);
	}
}
int find(int v, int l, int r, int x, int y) {
	if(l == x && r == y) return f[v];
	int mid = (l + r) / 2, s;
	if(bz[v]) {
		bz[v * 2] = bz[v * 2 + 1] = bz[v];
		f[v * 2] = bz[v] - qu[l] + 1;
		f[v * 2 + 1] = bz[v] - qu[mid + 1] + 1;
		bz[v] = 0;
	}
	if(y <= mid) s = find(v * 2, l, mid, x, y);
	else if(x > mid) s = find(v * 2 + 1, mid + 1, r, x, y);
	else s = max(find(v * 2, l, mid, x, mid), find(v * 2 + 1, mid + 1, r, mid + 1, y));
	f[v] = max(f[v * 2], f[v * 2 + 1]);
	return s;
}
void add(int v, int l, int r, int x, int c) {
	if(l == r) g[v] = max(g[v], c);
	else {
		int mid = (l + r) / 2;
		if(x <= mid) add(v * 2, l, mid, x, c);
		else add(v * 2 + 1, mid + 1, r, x, c);
		g[v] = max(g[v * 2], g[v * 2 + 1]);
	}
}
int get(int v, int l, int r, int x, int y) {
	if(l == x && r == y) return g[v];
	int mid = (l + r) / 2;
	if(y <= mid) return get(v * 2, l, mid, x, y);
	if(x > mid)  return get(v * 2 + 1, mid + 1, r, x, y);
	return max(get(v * 2, l, mid, x, mid), get(v * 2 + 1, mid + 1, r, mid + 1, y));
}
void solve() {
	sort(q + 1, q + Q + 1, cmp);
	memset(to, 0, sizeof(to));
	memset(f, 0, sizeof(f));
	memset(g, 0, sizeof(g));
	memset(bz, 0, sizeof(bz));
	memset(qu, 0, sizeof(qu));
	for(int i = 1; i <= n; i++)	{
		while(qu[0] && a[i] >= a[qu[qu[0]]]) qu[0]--;
		to[i] = qu[qu[0]];
		qu[++qu[0]] = i;
	}
	memset(qu, 0, sizeof(qu));
	int j = 1;
	for(int i = 1; i <= n; i++) {
		if(a[i] >= a[qu[qu[0]]]) {
			int t = 0, l = 1, r = qu[0];
			while(l <= r) {
				int mid = (l + r) / 2;
				if(qu[mid] > to[i]) t = mid, r = mid - 1; else l = mid + 1;
			}
			if(t && qu[0]) change(1, 1, n, t, qu[0], i);
			qu[++qu[0]] = i;
		}
		else {
			while(qu[0] && a[i] < a[qu[qu[0]]]) {
				int s = find(1, 1, n, qu[0], qu[0]);
				add(1, 1, n, qu[qu[0]], s);
				change(1, 1, n, qu[0], qu[0], 0);
				qu[0]--;
			}
			qu[++qu[0]] = i;
		}
		while(j <= Q && q[j].r == i) {
			ans[q[j].id] = max(ans[q[j].id], get(1, 1, n, q[j].l, q[j].r));
			int t = 0, l = 1, r = qu[0];
			while(l <= r) {
				int mid = (l + r) / 2;
				if(qu[mid] >= q[j].l) t = mid, r = mid - 1; else l = mid + 1;
			}
			ans[q[j].id] = max(ans[q[j].id], find(1, 1, n, t, qu[0]));
			j++;
		}
	}
}
int main() {
	int i, j, k;
	scanf("%d", &n);
	for(i = 1; i <= n; i++) scanf("%d", &a[i]);
	scanf("%d", &Q);
	for(i = 1; i <= Q; i++) scanf("%d%d", &q[i].l, &q[i].r), q[i].id = i;
	solve();
	for(i = 1; i <= n / 2; i++) swap(a[i], a[n - i + 1]);
	for(i = 1; i <= Q; i++) swap(q[i].l, q[i].r), q[i].l = n - q[i].l + 1, q[i].r = n - q[i].r + 1;
	solve();
	for(i = 1; i <= Q; i++) printf("%d
", max(ans[i], 1));
	return 0;
}
哈哈哈哈哈哈哈哈哈哈
原文地址:https://www.cnblogs.com/LZA119/p/13910032.html