Luogu P6477 [NOI Online #2 提高组]子序列问题

(large{题目链接})
(\)

题意:

给定一个长度为(n)的正整数序列,定义函数(f_{l,r})表示在下标在(left[l,r ight])的子区间中不同整数的个数。
求:(sum limits^{n}_{l=1} sum limits ^{n}_{r=l}fleft( l,r ight)^{2} left(mod 1e9 + 7 ight))
(1 leq n leq 10^6)
(\)

思路:

首先看到(10^9)的值域,而且关心的只是数值相等不相等,与具体值无关,先离散化一下。
我们枚举左端点(l),考虑当左端点为(l)的区间对答案的贡献,把这些贡献全部加在一起就是最终的答案。
那么题目就变成求 (sum limits _{i = 1} ^ {n} f(l,i)^2)
因为(n)的范围是(10^6),显然要找到一种方法能够维护答案。
对于(left[l,n ight])中出现过的数(x),设它在(left[l,n ight])出现的最左位置为(pos_x)。记(t_i)(f(l,i))的值。
考虑倒序循环(l),那么左端点由(l+1)变为(l)的时候,会发生两种事。

1.(t_l,t_{l+1},...,t_{pos_x-1})都加1。

2.(pos_x)变为l。

那么所需要解决的问题就变为了:

1.支持区间修改。

2.求区间的平方和。

可以用线段树维护。如果区间加上(k),那么平方和变为:

[left( a_{l}+k ight) ^{2}+left( a_{l+1}+k ight) ^{2}+ldots +left( a_{r}+k ight) ^{2} ]

[= a^{2}_{l}+2ka_{l} + k^{2} + a^{2}_{l+1}+2ka_{l+1} + k^{2} +...+ a^{2}_{r}+2ka_{r} + k^{2} ]

[= left( a^{2}_{l}+a^{2}_{l+1}+ldots +a^{2}_{r} ight) + 2k(a_{l}+a_{l+1}+ldots +a_{r}) + (r- l+ 1) imes k ^ 2 ]

维护区间和和区间平方和即可。
(\)

代码:

#include <bits/stdc++.h>
#define ls (x << 1)
#define rs (x << 1 | 1)
using namespace std;

typedef long long ll;

const int N = 1e6 + 5;
const int p = 1e9 + 7; 

int n, a[N], pos[N];
struct Node {
	int id, val;
}b[N];

int read() {
	int x = 0;
	char c = getchar();
	for (; !isdigit(c); c = getchar());
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	return x;
}

bool cmp(Node x, Node y) { return x.val < y.val; }

struct Segment_tree {
	int tl[N << 2], tr[N << 2];
	ll t[N << 2], lz[N << 2], s[N << 2];
	
	void build(int x, int l, int r) {
		tl[x] = l, tr[x] = r;
		if (l == r) return;
		int mid = (l + r) >> 1;
		build(ls, l, mid);
		build(rs, mid + 1, r);
	}
	
	void up(int x) {
		s[x] = s[ls] + s[rs];
		if (s[x] > p) s[x] -= p;
		t[x] = t[ls] + t[rs];
		if (t[x] > p) t[x] -= p;
	}
	
	void down(int x) {
		if (!lz[x]) return;

    	        s[ls] = (s[ls] + 2 * lz[x] * t[ls] % p + (tr[ls] - tl[ls] + 1) * lz[x] * lz[x] % p) % p;
    	        t[ls] = (t[ls] + (tr[ls] - tl[ls] + 1) * lz[x] % p) % p;
    	        lz[ls] += lz[x];
    	
    	        s[rs] = (s[rs] + 2 * lz[x] * t[rs] % p + (tr[rs] - tl[rs] + 1) * lz[x] * lz[x] % p) % p;
    	        t[rs] = (t[rs] + (tr[rs] - tl[rs] + 1) * lz[x] % p) % p;
    	        lz[rs] += lz[x];
    	
    	        lz[x] = 0;
	}
	
	void update(int x, int l, int r, ll k) {
		if (l <= tl[x] && r >= tr[x]) {
        	    s[x] = (s[x] + 2 * k * t[x] % p + (tr[x] - tl[x] + 1) * k * k % p) % p;
        	    t[x] = (t[x] + (tr[x] - tl[x] + 1) * k % p) % p;
        	    lz[x] += k;
        	    return;
		}
		down(x);
		int mid = (tl[x] + tr[x]) >> 1;
		if (l <= mid) update(ls, l, r, k);
		if (r >= mid + 1) update(rs, l, r, k);
		up(x);
	}
	
	ll query(int x, int l, int r) {
		if (l <= tl[x] && r >= tr[x]) return s[x];
		ll ret = 0;
		int mid = (tl[x] + tr[x]) >> 1;
		if (l <= mid) ret = query(ls, l, r);
		if (r >= mid + 1) ret = (ret + query(rs, l, r)) % p;
		return ret;
	}
}T;

int main() {
	n = read();
	for (int i = 1; i <= n; ++i) b[i].id = i;
	for (int i = 1; i <= n; ++i) b[i].val = read();
	sort(b + 1, b + 1 + n, cmp);
	int cnt = 0;
	b[0].val = b[1].val - 1;
	for (int i = 1; i <= n; ++i) b[i].val == b[i - 1].val ? a[b[i].id] = cnt : a[b[i].id] = ++cnt;
	for (int i = 1; i <= cnt; ++i) pos[i] = n + 1;
	T.build(1, 1, n);
	ll ans = 0;
	for (int i = n; i >= 1; --i) {
		T.update(1, i, pos[a[i]] - 1, 1);
		ans = (ans + T.query(1, i, n)) % p;
		pos[a[i]] = i;
	}
	printf("%lld
", ans);
	return 0;
} 
原文地址:https://www.cnblogs.com/Miraclys/p/12775282.html