数组

题意

给定一个序列,支持单点修改,查询有多少个子区间满足区间内元素互不相同。

题解

我们记数组$last_i$表示上一个与第$i$个元素相同的位置,所以一定有$last_i<i$。

一个区间$[L,R]$合法当且仅当$last_i<L(iin [L,R])$。

所以对于一个固定的右端点$R$,它对答案的贡献一定是$R-maxspace last_i(ileq R)$。

所以每次询问的答案就是$frac {n(n+1)}{2}-sumlimits_{i=1}^{n} maxspace last_jspace(0<jleq i)$

不难发现$maxspace last_jspace(0<jleq i)$是一个前缀的最大值,我们只需要用线段树维护单调栈即可。

具体做法是每次通过左右儿子来更新当前节点时分类讨论。

首先由于从前向后维护单调栈,那么左侧部分$L$一定会直接对答案产生贡献,而右半部分$R$从中间拆开成$ls,rs$,分两种情况讨论。

1、$ls$最大值小于等于$L$的最大值,那么$ls$一定在单调栈中可以完全被$L$所代替,那么$L$的对最终和的贡献只有$L$的最大值乘以$ls$的区间长度,递归处理$rs$即可。

2、若$ls$的最大值已然大于$L$的最大值了,那么单调栈中所有$rs$的部分一定会完整地贡献给答案,然后递归处理$ls$即可。注意,这里计算$Ans_{rs}$的时候不能直接使用$Ans_{rs}$,而是要使用$Ans_{R}-Ans_{ls}$来更新,原因是,$rs$的答案并非直接贡献给了$R$,具体的可以看下面这张图红色部分提到的一种可能性。

在修改时,对于每一种颜色再开一个$set$,方便每次修改时直接求颜色内该位置的前驱后继,找出哪些地方的$last_i$改变了更新即可。

复杂度为$O(nlog^2n)$。

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<set>
#define LL long long
#define M 200020
#define mid ((l+r)>>1)
using namespace std;
int read(){
	int nm=0,fh=1; char cw=getchar();
	for(;!isdigit(cw);cw=getchar()) if(cw=='-') fh=-fh;
	for(;isdigit(cw);cw=getchar()) nm=nm*10+(cw-'0');
	return nm*fh;
}
set<int>col[M];
int n,m,val[M],p[M<<2],c[M],v[M];
LL sum[M<<2],tot;
LL calc(int x,int l,int r,int maxn){
	if(p[x]<=maxn) return (LL)(r-l+1)*(LL)maxn; if(l==r) return sum[x];
	if(p[x<<1]>maxn) return calc(x<<1,l,mid,maxn)+sum[x]-sum[x<<1];
	else return (LL)(mid-l+1)*(LL)maxn+calc(x<<1|1,mid+1,r,maxn);
}
void pushup(int x,int l,int r){
	p[x]=max(p[x<<1],p[x<<1|1]);
	sum[x]=sum[x<<1]+calc(x<<1|1,mid+1,r,p[x<<1]);
}
void build(int x,int l,int r){
	if(l==r){sum[x]=p[x]=v[l];return;}
	build(x<<1,l,mid),build(x<<1|1,mid+1,r),pushup(x,l,r);
}
void change(int x,int l,int r,int pos,int num){
	if(l==r){sum[x]=p[x]=num;return;}
	if(pos<=mid) change(x<<1,l,mid,pos,num);
	else change(x<<1|1,mid+1,r,pos,num); pushup(x,l,r);	
}
int main(){
	n=read(),tot=(LL)n*(LL)(n+1),tot>>=1;
	set<int>::iterator it,pre,suf;
	for(int i=1;i<=n;i++) col[i].insert(0);
	for(int i=1;i<=n;i++){
	    c[i]=read(),pre=col[c[i]].end();
		pre--,v[i]=*pre,col[c[i]].insert(i);
	} build(1,1,n);
	for(int T=read();T;T--){
		if(!read()){printf("%lld
",tot-sum[1]);continue;}
		int pos=read(),num=read();
		suf=pre=col[c[pos]].find(pos),pre--,suf++;
		if(suf!=col[c[pos]].end()) change(1,1,n,*suf,*pre);
		col[c[pos]].erase(pos),col[c[pos]=num].insert(pos);
		suf=pre=col[num].find(pos),pre--,suf++,change(1,1,n,pos,*pre);
		if(suf!=col[num].end()) change(1,1,n,*suf,pos);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/OYJason/p/9724535.html