[CF1083D]The Fair Nut’s getting crazy[单调栈+线段树]

题意

给定一个长度为 (n) 的序列 ({a_i})。你需要从该序列中选出两个非空的子段,这两个子段满足

  • 两个子段非包含关系。
  • 两个子段存在交。
  • 位于两个子段交中的元素在每个子段中只能出现一次。

求共有多少种不同的子段选择方案。输出总方案数对 (10^9 + 7) 取模后的结果。

需要注意的是,选择子段 ([a, b])([c, d]) 与选择子段 ([c, d])([a, b]) 被视为是相同的两种方案。

(1 leq n leq 10^5, -10^9 leq a_i leq 10^9)

分析

  • 考虑枚举一个区间 ([b,c]) 作为交,记录 (L_i,R_i) 表示距离 (i) 最近的和 (i) 颜色相同的位置。

  • 有: (ain[maxlimits_{i=b}^c{L_i},b),din(c,minlimits_{i=b}^c{R_i}])

  • 记录可以取到的左端点的最小值(满足区间中不存在两个相同的数) (pos)(mi, mx) 分别表示 ([j,i])(R) 的极小值和 (L) 的极大值。

  • 考虑从左到右枚举交区间的右端点 (i) ,用单调栈维护每个位置的 (mi, mx) 。容易得到以 (i) 为交区间的右端点的方案数为 (sum_{j=pos}^i(mi_j-i)(j-mx_j)​),拆开然后用线段树分别维护。

  • 总时间复杂度为 (O(nlogn))

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
    int x = 0,f = 1;
    char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
    return x * f;
}
template <typename T> inline void Max(T &a, T b){if(a < b) a = b;}
template <typename T> inline void Min(T &a, T b){if(a > b) a = b;}
const int N = 1e5 + 7, mod = 1e9 + 7;
int n, vc;
LL ans;
int lst[N], L[N], R[N], V[N], a[N];
int st1[N], st2[N], tp1, tp2;
#define Ls o << 1
#define Rs (o << 1 | 1)
LL s1(int n) {
	return 1ll * n * (n + 1) / 2;
}
LL ami[N << 2], amx[N << 2];
struct data {
	LL mi, mx, smi, tm;
	data operator +(const data &rhs) const {
		return (data){ (mi + rhs.mi) % mod, (mx + rhs.mx) % mod, (smi + rhs.smi) % mod, (tm + rhs.tm) % mod};
	}
}t[N << 2];
void add(LL &a, LL b) {
	a += b;if(a >= mod) a -= mod;
}
void stmi(int l, int r, int o, int v) {
	add(ami[o], v);
	add(t[o].tm, 1ll * v * t[o].mx % mod);
	add(t[o].mi, 1ll * (r - l + 1) * v % mod);
	add(t[o].smi, (s1(r) - s1(l - 1)) % mod * v % mod);
}
void stmx(int l, int r, int o, int v) {
	add(amx[o], v);
	add(t[o].tm, 1ll * v * t[o].mi % mod);
	add(t[o].mx, 1ll * (r - l + 1) * v % mod);
}
void pushdown(int l, int r, int o) {
	int mid = l + r >> 1;
	if(ami[o]) {
		stmi(l, mid, Ls, ami[o]);
		stmi(mid + 1, r, Rs, ami[o]);
	}
	if(amx[o]) {
		stmx(l, mid, Ls, amx[o]);
		stmx(mid + 1, r, Rs, amx[o]);
	}
	ami[o] = amx[o] = 0;
}
void pushup(int o) {
	t[o] = t[Ls] + t[Rs];
}
void modify(int L, int R, int l, int r, int o, int v, int opt) {
	if(L <= l && r <= R) {
		if(!opt) stmi(l, r, o, v);
		else stmx(l, r, o, v);
		return;
	}
	pushdown(l, r, o);int mid = l + r >> 1;
	if(L <= mid) modify(L, R, l, mid, Ls, v, opt);
	if(R > mid)  modify(L, R, mid + 1, r, Rs, v, opt);
	pushup(o);
}
data query(int L, int R, int l, int r, int o) {
	if(L <= l && r <= R) return t[o];
	pushdown(l, r, o);int mid = l + r >> 1;
	if(R <= mid) return query(L, R, l, mid, Ls);
	if(L > mid)  return query(L, R, mid + 1, r, Rs);
	return query(L, R, l, mid, Ls) + query(L, R, mid + 1, r, Rs);
}
int main() {
	n = gi();
	rep(i, 1, n) a[i] = gi(), V[i] = a[i];
	sort(V + 1, V + 1 + n);
	vc = unique(V + 1, V + 1 + n) - V - 1;
	rep(i, 1, n) a[i] = lower_bound(V + 1, V + 1 + vc, a[i]) - V;
	rep(i, 1, n) {
		L[i] = lst[a[i]] + 1;
		lst[a[i]] = i;
	}
	rep(i, 1, vc) lst[i] = n + 1;
	for(int i = n; i; --i) {
		R[i] = lst[a[i]] - 1;
		lst[a[i]] = i;
	}
	for(int i = 1, gg = 1; i <= n; ++i) {
		for(; tp1 && L[i] >= L[st1[tp1]]; --tp1) {
			modify(st1[tp1 - 1] + 1, st1[tp1], 1, n, 1, mod - L[st1[tp1]], 1);
		}
		modify(st1[tp1] + 1, i, 1, n, 1, L[i], 1);
		st1[++tp1] = i;
		for(; tp2 && R[i] <= R[st2[tp2]]; --tp2) {
			modify(st2[tp2 - 1] + 1, st2[tp2], 1, n, 1, mod - R[st2[tp2]], 0);
		}
		modify(st2[tp2] + 1, i, 1, n, 1, R[i], 0);
		st2[++tp2] = i;
		
		Max(gg, L[i]);
		data res = query(gg, i, 1, n, 1);
		LL tmp = ((res.smi + i * res.mx % mod - res.tm - (s1(i) - s1(gg - 1)) % mod * i % mod) % mod + mod) % mod;
		add(ans, tmp);
	}
	printf("%lld
", ans);
	return 0;
}
原文地址:https://www.cnblogs.com/yqgAKIOI/p/10212225.html