Codeforces 997E Good Subsegments (线段树)

题目链接

https://codeforces.com/contest/997/problem/E

题解

经典题,鸽了 159 天终于看明白题解了。。

考虑一个区间是连续的等价于这个区间内的 ((max-min)-(r-l)=0),否则该值 (gt 0).
那么我们考虑从小到大枚举右端点 (r),当 (r) 变为 ((r+1)) 时,对于每个 (l),上述值的变化形式就是区间加,以当前的 (r) 为右端点的好的区间个数就等于 (0) 的个数,这个可以通过维护最小值和最小值的个数来实现。直接线段树维护就可以得到以 (r) 为右端点的好的区间个数。

但是我们要求的是一个区间内的好的子区间个数,也就是当 (r) 在一个区间中时某个 (l) 区间内 (0) 的个数之和。于是我们要做的就是在线段树上维护最小值为 (0) 的个数的历史和。

这个的维护方法就是,考虑再在每个节点维护 (0) 的个数的历史和以及一个 (tag) 标记表示当前区间的最小值要被算多少次,然后每次给 ([1,r]) 这个区间拆成的 (O(log)) 个线段树区间中 最小值为 (0) 的节点的 (tag) 增加 (1),然后访问区间时下放标记同时维护历史和即可。

时间复杂度 (O(n+qlog n)).

代码

#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define iter iterator
#define riter reversed_iterator
#define y1 Lorem_ipsum_dolor
using namespace std;

inline int read()
{
	int x = 0,f = 1; char ch = getchar();
	for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
	for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
	return x*f;
}

const int mxN = 1.2e5;
struct SgTNode
{
	int mn,cnt,tag; llong sum,tag2;
} sgt[mxN*4+3];
void pushdown(int u)
{
	if(sgt[u].tag)
	{
		sgt[u<<1].mn += sgt[u].tag; sgt[u<<1].tag += sgt[u].tag;
		sgt[u<<1|1].mn += sgt[u].tag; sgt[u<<1|1].tag += sgt[u].tag;
		sgt[u].tag = 0;
	}
	if(sgt[u].tag2)
	{
		if(sgt[u<<1].mn==sgt[u].mn) {sgt[u<<1].sum += sgt[u<<1].cnt*sgt[u].tag2; sgt[u<<1].tag2 += sgt[u].tag2;}
		if(sgt[u<<1|1].mn==sgt[u].mn) {sgt[u<<1|1].sum += sgt[u<<1|1].cnt*sgt[u].tag2; sgt[u<<1|1].tag2 += sgt[u].tag2;}
		sgt[u].tag2 = 0ll;
	}
}
void pushup(int u)
{
	sgt[u].mn = min(sgt[u<<1].mn,sgt[u<<1|1].mn);
	sgt[u].cnt = (sgt[u<<1].mn==sgt[u].mn?sgt[u<<1].cnt:0)+(sgt[u<<1|1].mn==sgt[u].mn?sgt[u<<1|1].cnt:0);
	sgt[u].sum = sgt[u<<1].sum+sgt[u<<1|1].sum;
}
void build(int u,int le,int ri)
{
	if(le==ri) {sgt[u].cnt = 1; return;}
	int mid = (le+ri)>>1;
	build(u<<1,le,mid); build(u<<1|1,mid+1,ri);
	pushup(u);
}
void add(int u,int le,int ri,int lb,int rb,int x)
{
//	if(u==1) {printf("add %d %d %d
",lb,rb,x);}
	if(le>=lb&&ri<=rb) {sgt[u].mn += x,sgt[u].tag += x; return;}
	int mid = (le+ri)>>1; pushdown(u);
	if(lb<=mid) {add(u<<1,le,mid,lb,rb,x);} if(rb>mid) {add(u<<1|1,mid+1,ri,lb,rb,x);}
	pushup(u);
}
void modify(int u,int le,int ri,int lb,int rb)
{
	if(le>=lb&&ri<=rb)
	{
		if(sgt[u].mn==0) {sgt[u].sum += sgt[u].cnt,sgt[u].tag2++;}
		return;
	}
	int mid = (le+ri)>>1; pushdown(u);
	if(lb<=mid) {modify(u<<1,le,mid,lb,rb);} if(rb>mid) {modify(u<<1|1,mid+1,ri,lb,rb);}
	pushup(u);
}
llong query(int u,int le,int ri,int lb,int rb)
{
	if(le>=lb&&ri<=rb) {return sgt[u].sum;}
	int mid = (le+ri)>>1; pushdown(u); llong ret = 0ll;
	if(lb<=mid) {ret += query(u<<1,le,mid,lb,rb);} if(rb>mid) {ret += query(u<<1|1,mid+1,ri,lb,rb);}
	pushup(u); return ret;
}

struct Query
{
	int l,r,id;
	bool operator <(const Query &arg) const {return r<arg.r;}
} qr[mxN+3];
llong ans[mxN+3];
int stk1[mxN+3],stk2[mxN+3];
int a[mxN+3];
int n,q,tp1,tp2;

int main()
{
	n = read();
	for(int i=1; i<=n; i++) a[i] = read();
	q = read();
	for(int i=1; i<=q; i++) scanf("%d%d",&qr[i].l,&qr[i].r),qr[i].id = i;
	sort(qr+1,qr+q+1);
	build(1,1,n);
	for(int i=1,j=1; i<=n; i++)
	{
//		printf("i=%d
",i);
		while(tp1>0&&a[i]>a[stk1[tp1]])
		{
			add(1,1,n,stk1[tp1-1]+1,stk1[tp1],a[i]-a[stk1[tp1]]);
			tp1--;
		}
		stk1[++tp1] = i;
		while(tp2>0&&a[i]<a[stk2[tp2]])
		{
			add(1,1,n,stk2[tp2-1]+1,stk2[tp2],+a[stk2[tp2]]-a[i]);
			tp2--;
		}
		stk2[++tp2] = i;
		if(i>1) {add(1,1,n,1,i-1,-1);}
		modify(1,1,n,1,i);
		while(j<=q&&qr[j].r==i)
		{
			ans[qr[j].id] = query(1,1,n,qr[j].l,qr[j].r);
			j++;
		}
	}
	for(int i=1; i<=q; i++) printf("%I64d
",ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/suncongbo/p/12654108.html