【洛谷P6477】子序列问题

题目

题目链接:https://www.luogu.com.cn/problem/P6477

题目描述

给定一个长度为 \(n\) 的正整数序列 \(A_1\), \(A_2\), \(\cdots\), \(A_n\)。定义一个函数 \(f(l,r)\) 表示:序列中下标在 \([l,r]\) 范围内的子区间中,不同的整数个数。换句话说,\(f(l,r)\) 就是集合 \(\{A_l,A_{l+1},\cdots,A_r\}\) 的大小,这里的集合是不可重集,即集合中的元素互不相等。

现在,请你求出 \(\sum_{l=1}^n\sum_{r=l}^n (f(l,r))^2\)。由于答案可能很大,请输出答案对 \(10^9 +7\) 取模的结果。

思路

考试时居然先写了这题再写 T1 的。。。主要是 T1 一眼没看出结论。

我们考虑枚举右端点 \(r\),对于第 \(l\) 个数,我们记录 \(f[l]\) 表示 \([l,r]\) 中不同的数字个数。

假设我们已经通过某种奇妙的方法求出了 \(r=i\) 时的 \(f\)。接下来我们求 \(r=i+1\) 时的 \(f\)

  • 如果 \(a[i+1]\) 在之前没出现过,那么显然每一个区间都多了一个不同的数。\(f[1\sim i+1]\) 全部加一。
  • 如果 \(a[i+1]\) 在第 \(j\) 位出现过 \((j<i)\),那么 \([1,j]\) 出现的数不变,\([j+1,i+1]\) 出现的数加一。也就是 \(f[j+1\sim i+1]\) 全部加一。

维护平方?套路性拆开,依旧维护区间平方和、区间和即可。

线段树就可以轻松解决这些问题。

时间复杂度 \(O(n\log n)\)

代码

#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

const int N=1000010,MOD=1e9+7;
int n,a[N],b[N],last[N];
ll ans;

inline int read()
{
	int d=0; char ch=getchar();
	while (!isdigit(ch)) ch=getchar();
	while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
	return d;
}

struct SegTree
{
	int l[N*4],r[N*4],len[N*4],sum[N*4],lazy[N*4];
	ll mul[N*4];
	
	void build(int x,int ql,int qr)
	{
		l[x]=ql; r[x]=qr; len[x]=qr-ql+1;
		if (ql==qr) return;
		register int mid=(ql+qr)>>1;
		build(x*2,ql,mid); build(x*2+1,mid+1,qr);
	}
	
	void update(int x,int ql,int qr)
	{
		if (l[x]==ql && r[x]==qr)
		{
			mul[x]=(mul[x]+2LL*sum[x]+len[x])%MOD;
			sum[x]=(sum[x]+len[x])%MOD;
			lazy[x]++;
			return;
		}
		if (lazy[x])
		{
			ll p=lazy[x]; register int lc=x*2,rc=x*2+1;
			lazy[lc]+=p; lazy[rc]+=p;
			mul[lc]=(mul[lc]+2LL*p*sum[lc]+len[lc]*p*p)%MOD;
			sum[lc]=(sum[lc]+len[lc]*p)%MOD;
			mul[rc]=(mul[rc]+2LL*p*sum[rc]+len[rc]*p*p)%MOD;
			sum[rc]=(sum[rc]+len[rc]*p)%MOD;
			lazy[x]=0;
		}
		register int mid=(l[x]+r[x])>>1;
		if (qr<=mid) update(x*2,ql,qr);
		else if (ql>mid) update(x*2+1,ql,qr);
		else update(x*2,ql,mid),update(x*2+1,mid+1,qr);
		sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
		mul[x]=(mul[x*2]+mul[x*2+1])%MOD;
	}
}seg;

int main()
{
	n=read();
	for (register int i=1;i<=n;i++)
		a[i]=b[i]=read();
	sort(b+1,b+1+n);
	register int tot=unique(b+1,b+1+n)-b-1;
	for (register int i=1;i<=n;i++)
		a[i]=lower_bound(b+1,b+1+n,a[i])-b;
	seg.build(1,1,n);
	for (register int i=1;i<=n;i++)
	{
		if (last[a[i]])
			seg.update(1,last[a[i]]+1,i);
		else
			seg.update(1,1,i);
		last[a[i]]=i;
		ans=(ans+seg.mul[1])%MOD;
	}
	printf("%lld",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/stoorz/p/12773799.html