8.28 ISN

题意

给定一个(N) 个数的序列 (A),如果序列 (A) 不是非降序的,你需要在其中选择一个数删掉,
不断重复这个操作直到序列 (A) 非降。求有多少种不同的删数方案。注意:删掉的数的集
合相同,但是删数的顺序不同,视作不同的删数方案。

答案对(10^9+7)取模

(Nleq 2000,A_i leq 2000)


解法

这题十分巧妙,结合了容斥和DP

首先用DP求出这个序列中所有的一定长度的非降序列以及它们的个数

(f[i][j])为长度为i,末尾为(a[j])的非降序列的数目,那么转移就十分显然了

[f[i][j]=sum f[i - 1][k] (k <j,a[k]<a[j]) ]

可以看出,上面的式子要求和,还与位置与权值大小有关,可以考虑用树状数组进行优化。

控制每次转移时树状数组中的元素都在当前枚举位置之前,保证位置的合法

枚举长度,长度每次改变时清空树状数组,把上一长度求出的(dp)值赋在树状数组中,达到转移的目的

这样就能求出(c[i])了 ((c[i])为长度为i的非降序列的个数)

如果不考虑非法情况,最后的答案

[Ans=sum_{i=1}^n c[i]*(n-i)! ]

用容斥原理去除非法情况

如果用上面的方式算出答案,在什么情况下会出现重复呢?

我们在计算剩下((n-i))个数的删除顺序时,有可能在某一个时刻序列已经是非降的了,按照题意应该停止;但是我们没有停止。就是这一部分构成了非法的情况

如何去掉这种非法情况呢?

在形成长度为i的非降序列之前,我们还要删掉一个数:如果在删掉这个数之前,整个序列就已经是非降的了,那么这一种情况代表的方案就是所有非法的情况

由于非降序列删去任意一个数仍是非降序列,所以这个删去的数有((i+1))种取值,也就是会贡献((i+1))个非法序列

又因为所有长度为(i)的非降序列一定包含在长度为(i+1)的非降序列中

也就是说,只要存在长度为(i+1)的非降序列,就一定有对于长度为(i)的不合法情况

我们可以先构成一个长度为(i+1)的非降序列,即(c[i+1]*(n-i-1)!)

当然,在构成长度为(i+1)非降序列时,也会有不合法的情况,但是反正是要求不合法的情况,这种不合法的情况也应该包括进去(我在说什么@#$@#%#^)

所以答案为

[Ans=sum_{i=1}^n c[i]*(n-i)!-(i+1)*c[i+1]*(n-i-1)! ]


代码

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 2500;
const int mod = 1e9 + 7;

int n, rng;

int a[N], o[N];
int fac[N], sum[N], f[N][N];

struct BIT {
	
	int c[N];
	
	void clear() {
		memset(c, 0, sizeof c);
	}
	
	void insert(int p, int v) {
		if (!p || !v)	return;
		for (; p <= rng; p += p & -p)	(c[p] += v) %= mod;
	}	
	
	int query(int p) {
		int res = 0;
		for (; p; p -= p & -p)	(res += c[p]) %= mod;
		return res;	
	}
} bit;

void config() {
	fac[0] = 1;
	for (int i = 1; i <= n; ++i)	fac[i] = 1LL * fac[i - 1] * i % mod;
	
	for (int i = 1; i <= n; ++i)	f[1][i] = 1;
	
	for (int i = 2; i <= n; ++i) {
		bit.clear();
		for (int j = i - 1; j <= n; ++j) {
			if (j >= i)	f[i][j] = bit.query(a[j]);
			bit.insert(a[j], f[i - 1][j]);
		}
	}
	
	for (int i = 1; i <= n; ++i)
		for (int j = i; j <= n; ++j)	sum[i] = (sum[i] + f[i][j]) % mod;
}

int main() {
	
	freopen("strong.in", "r", stdin);
	freopen("strong.out", "w", stdout);
	
	scanf("%d", &n);
	
	for (int i = 1; i <= n; ++i)	scanf("%d", a + i);
	rng = *max_element(a + 1, a + n + 1);
	
	config();
	
	long long ans = 0;
	for (int i = 1; i <= n; ++i)
		ans = (ans + 1LL * sum[i] * fac[n - i] % mod - 1LL * sum[i + 1] * fac[n - i - 1] % mod * (i + 1) % mod + mod) % mod;
	
	printf("%lld
", ans);
		
	return 0;
}
原文地址:https://www.cnblogs.com/VeniVidiVici/p/11424836.html