题解 Nasty Donchik 一道数据结构题

题目大意

题目链接 比赛链接

给定一个长度为(n)的序列(a_1,a_2dots,a_n)。保证(forall i:1leq a_ileq n)。请你求出,序列里有多少三元组((i,j,k)),满足(a[i,j])里的所有数,都在(a[j+1,k])里出现过;且(a[j+1,k])里所有数,都在(a[i,j])里出现过。

(nleq 2 imes 10^5)

本题题解

枚举(k)。对每个(j),维护使三元组((i,j,k))合法的最小的和最大的(i),分别记为( ext{mini}[j], ext{maxi}[j])。那么,当前(k)的三元组数量就是:(sum_{j=1}^{k-1}( ext{maxi}[j]- ext{mini}[j]+1))。考虑分别计算( ext{maxi})的和和( ext{mini})的和。

记每个位置(t)上的数上一次和下一次出现的位置分别为( ext{pre}[t])( ext{nxt}[t]),特别地,如果前面/后面没有相同的数,则( ext{pre}[t]=0)( ext{nxt}[t]=n+1)。那么,我们发现,三元组((i,j,k))合法的充分必要条件是:(max_{t=i}^{j}( ext{nxt}[t])leq k),且(min_{t=j+1}^{k}( ext{pre}[t])geq i)

由此可知,( ext{maxi}[j])就是满足(min_{t=j+1}^{k}( ext{pre}[t])geq i)的最大的(i)( ext{mini}[j])就是满足(max_{t=i}^{j}( ext{nxt[}t])leq k)的最小的(i)

( ext{maxi})比较好维护,他就等于(min_{t=j+1}^{k}( ext{pre}[j]))。当从(k-1)变到(k)时,我们让所有(jin[1,k-1])( ext{maxi}[j])( ext{pre}[k])(min)即可。

考虑( ext{mini})。我们称( ext{nxt}[t]>k)的位置为不合法的,其他位置为合法的。那么对于每个(j)( ext{mini}[j])就相当于(j)前面、最靠近(j)的那个不合法的位置(+1)。特别地,如果(j)本身就不合法,我们认为( ext{mini}[j]=j+1)。从(k-1)变到(k),会使得所有( ext{nxt}[t]=k)的位置,从不合法变成合法。相当于把两段( ext{mini})的区间“合并”起来(令后一段区间的值等于前一段区间的值)。而( ext{nxt}[t]=k)的位置最多只有一个:就是( ext{pre}[k])。所以每次对一段区间执行区间覆盖(或者区间取(min))即可(事实上因为( ext{maxi})要支持的是区间取(min),所以都用区间取(min)反而更好写)。

还有一个要注意的点是,我们要始终保证,( ext{mini}[j]leq ext{maxi}[j]+1),所以对( ext{maxi})(min)的时候,要对( ext{mini})做一样的操作。

总结来说,需要支持区间对一个数取(min),区间求和,可以用吉老师线段树实现。另外,我们还要对一个位置求它前面、最靠近它的不合法的位置,同时要支持单点修改(把某个位置从不合法变为合法),这个可以用线段上二分实现。

时间复杂度(O(nlog n))

参考代码:

#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
	const int MAXN=1<<20;
	char buf[MAXN],*S,*T;
	inline char getchar(){
		if(S==T){
			T=(S=buf)+fread(buf,1,MAXN,stdin);
			if(S==T)return EOF;
		}
		return *S++;
	}
}
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // dysyn1314
const int MAXN=2e5;
int n;
/*
struct Baoli{
	int a[MAXN+5],val[MAXN+5],val2[MAXN+5];
	int get_nxt0(int p){
		for(int i=p;i<=n+1;++i)if(a[i]==0)return i;
		throw;
	}
	int get_pre0(int p){
		for(int i=p;i>=0;--i)if(a[i]==0)return i;
		throw;
	}
	void set1(int p){
		a[p]=1;
	}
	void init(){
		for(int i=1;i<=n;++i)val[i]=val2[i]=i;
	}
	void modify_min_mxi(int l,int r,int x){
		for(int i=l;i<=r;++i)val[i]=min(val[i],x);
	}
	void modify_min_mni(int l,int r,int x){
		for(int i=l;i<=r;++i)val2[i]=min(val2[i],x);
	}
	int get_sum_mxi(){
		int res=0;
		for(int i=1;i<=n;++i)res+=val[i]*a[i];
		return res;
	}
	int get_sum_mni(){
		int res=0;
		for(int i=1;i<=n;++i)res+=val2[i]*a[i];
		return res;
	}
}T;
*/
class SegmentTree{
private:
	int sz[MAXN*4+5],mx[2][MAXN*4+5],se[2][MAXN*4+5],ct[2][MAXN*4+5];
	ll sum[2][MAXN*4+5];
	void _pu(int p,int *mx,int *se,int *ct,ll *sum){
		sum[p]=sum[p<<1]+sum[p<<1|1];
		if(mx[p<<1]>mx[p<<1|1]){
			mx[p]=mx[p<<1];
			se[p]=max(se[p<<1],mx[p<<1|1]);
			ct[p]=ct[p<<1];
		}
		else if(mx[p<<1]<mx[p<<1|1]){
			mx[p]=mx[p<<1|1];
			se[p]=max(mx[p<<1],se[p<<1|1]);
			ct[p]=ct[p<<1|1];
		}
		else{
			mx[p]=mx[p<<1];
			se[p]=max(se[p<<1],se[p<<1|1]);
			ct[p]=ct[p<<1]+ct[p<<1|1];
		}
	}
	void push_up(int p){
		sz[p]=sz[p<<1]+sz[p<<1|1];
		_pu(p,mx[0],se[0],ct[0],sum[0]);
		_pu(p,mx[1],se[1],ct[1],sum[1]);
	}
	void _pd(int p,int *mx,int *ct,ll *sum){
		if(mx[p]<mx[p<<1]){
			sum[p<<1]-=(ll)ct[p<<1]*(mx[p<<1]-mx[p]);
			mx[p<<1]=mx[p];
		}
		if(mx[p]<mx[p<<1|1]){
			sum[p<<1|1]-=(ll)ct[p<<1|1]*(mx[p<<1|1]-mx[p]);
			mx[p<<1|1]=mx[p];
		}
	}
	void push_down(int p){
		_pd(p,mx[0],ct[0],sum[0]);
		_pd(p,mx[1],ct[1],sum[1]);
	}
	void build(int p,int l,int r){
		if(l==r){
			mx[0][p]=mx[1][p]=l;
			se[0][p]=se[1][p]=-1;
			return;
		}
		int mid=(l+r)>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		push_up(p);
	}
	void modify1(int p,int l,int r,int pos){
		if(l==r){
			sz[p]=1;
			ct[0][p]=ct[1][p]=1;
			sum[0][p]=mx[0][p];
			sum[1][p]=mx[1][p];
			return;
		}
		push_down(p);
		int mid=(l+r)>>1;
		if(pos<=mid)modify1(p<<1,l,mid,pos);
		else modify1(p<<1|1,mid+1,r,pos);
		push_up(p);
	}
	int __first0(int p,int l,int r){
		if(l==r){assert(sz[p]==0);return l;}
		push_down(p);
		int mid=(l+r)>>1;
		if(sz[p<<1]<mid-l+1)return __first0(p<<1,l,mid);
		else return __first0(p<<1|1,mid+1,r);
	}
	int _nxt0(int p,int l,int r,int ql,int qr){
		if(ql<=l && qr>=r){
			if(sz[p]==r-l+1)return n+1;
			else return __first0(p,l,r);
		}
		push_down(p);
		int mid=(l+r)>>1,res=n+1;
		if(ql<=mid&&sz[p<<1]<mid-l+1)res=_nxt0(p<<1,l,mid,ql,qr);
		if(res!=n+1)return res;
		if(qr>mid&&sz[p<<1|1]<r-mid)return _nxt0(p<<1|1,mid+1,r,ql,qr);
		else return n+1;
	}
	int __last0(int p,int l,int r){
		if(l==r){assert(sz[p]==0);return l;}
		push_down(p);
		int mid=(l+r)>>1;
		if(sz[p<<1|1]<r-mid)return __last0(p<<1|1,mid+1,r);
		else return __last0(p<<1,l,mid);
	}
	int _pre0(int p,int l,int r,int ql,int qr){
		if(ql<=l && qr>=r){
			if(sz[p]==r-l+1)return 0;
			else return __last0(p,l,r);
		}
		push_down(p);
		int mid=(l+r)>>1,res=0;
		if(qr>mid&&sz[p<<1|1]<r-mid)res=_pre0(p<<1|1,mid+1,r,ql,qr);
		if(res)return res;
		if(ql<=mid&&sz[p<<1]<mid-l+1)return _pre0(p<<1,l,mid,ql,qr);
		else return 0;
	}
	void modify2(int p,int l,int r,int ql,int qr,int x,int t){
		//区间对x取min
		if(x>=mx[t][p])return;
		if(ql<=l && qr>=r && se[t][p]<x){
			sum[t][p]-=(ll)ct[t][p]*(mx[t][p]-x);
			mx[t][p]=x;
			return;
		}
		push_down(p);
		int mid=(l+r)>>1;
		if(ql<=mid)modify2(p<<1,l,mid,ql,qr,x,t);
		if(qr>mid)modify2(p<<1|1,mid+1,r,ql,qr,x,t);
		push_up(p);
	}
public:
	//mxi tree0
	//mni tree1
	void set1(int p){modify1(1,1,n,p);}
	int get_nxt0(int p){
		if(p>n)return n+1;
		if(p<1)return 0;
		return _nxt0(1,1,n,p,n);
	}
	int get_pre0(int p){
		if(p>n)return n+1;
		if(p<1)return 0;
		return _pre0(1,1,n,1,p);
	}
	void modify_min_mxi(int l,int r,int x){
		if(l>r)return;
		modify2(1,1,n,l,r,x,0);
	}
	void modify_min_mni(int l,int r,int x){
		if(l>r)return;
		modify2(1,1,n,l,r,x,1);
	}
	ll get_sum_mxi(){return sum[0][1];}
	ll get_sum_mni(){return sum[1][1];}
	void init(){build(1,1,n);}
}T;
int a[MAXN+5],nxt[MAXN+5],pre[MAXN+5],pos[MAXN+5];
int main(){
	n=read();
	for(int i=1;i<=n;++i){a[i]=read();pre[i]=pos[a[i]];pos[a[i]]=i;}
	for(int i=1;i<=n;++i)pos[i]=n+1;
	for(int i=n;i>=1;--i){nxt[i]=pos[a[i]];pos[a[i]]=i;}
	T.init();
	ll ans=0;
	for(int k=1;k<=n;++k){
		if(pre[k]){
			int x=T.get_nxt0(pre[k]+1)-1;
			//cout<<"* "<<x<<" "<<T.get_pre0(pre[k]-1)<<endl;
			T.modify_min_mni(pre[k],x,T.get_pre0(pre[k]-1));
			T.set1(pre[k]);
		}
		T.modify_min_mxi(1,k-1,pre[k]);
		T.modify_min_mni(1,k-1,pre[k]);
		ans+=T.get_sum_mxi()-T.get_sum_mni();
	}
	cout<<ans<<endl;
	return 0;
}
原文地址:https://www.cnblogs.com/dysyn1314/p/12937520.html