排列计数机

牛客 - 排列计数机

statement

定义一个长为 (k) 的序列 (A_1,A_2,…,A_k) 的权值为:对于所有 (1≤i≤k)(max(A_1,A_2,…,A_i)) 有多少种不同的取值。

给出一个 (1)(n) 的排列 (B_1,B_2,…,B_n),求 (B) 的所有非空子序列的权值的 (m) 次方之和。

答案对 (10^9+7) 取模。

Hints

对于前 (10\%) 的数据,(n≤20)
对于前 (20\%) 的数据,(n≤100)
对于前 (40\%) 的数据,(n≤1000)
对于另外 (20\%) 的数据,(m=1)
对于所有数据,(1≤n≤10^5)(1≤m≤20),保证 (B)(1)(n) 的排列。

solution

先想一个暴力 DP,设 (dp(i, j, k)) 表示到前 (i) 个数,并且第 (i) 个数字必须选,最大值为 (j),权值为 (k) 的方案数,转移:

[dp(i, j, k) = sum_{l < i and pge a_i } dp(l, p, k) + sum_{l < i and p < a_i} dp(l, p, k - 1) ]

可以做到 (mathcal O (n ^ 4) - mathcal O (n ^ 3))

可以发现转移的时候,对于同一个二元组 ((j, k)) 对应的每一个 (i) 总是会在先前的状态原封不动地转移过来,所以可以把原先的第一维度压掉,最后的转移只有:

[egin{align} &dp(i, j) ightarrow dp(A_k, j + 1) & A_k > i\ &dp(i, j) ightarrow dp(i, j) & A_k le i end{align} ]

其中 (A_k) 是当前枚举到的数字。

由于这个 (m) 很小,尝试展开这个次幂的式子,首先知道二项式定理:

[(a + b) ^ m =sum_{k ge 0} {mchoose k} a ^ k b ^ {m - k} ]

这启发我们维护每个最大值对应的每个次幂的信息,由于 (m) 比较小,看起来非常赚。

(f(i, j)) 表示最大值为 (i) ,子序列个数的 (j) 次幂的大小,看下上面 (dp) 的转移。

  • (dp(i, j) leftarrow dp(i, j) A_k le i)

等价于 (f(i, j) leftarrow f(i, j) A_k le i)

  • (dp(A_k, j) leftarrow dp(i, j) A_k > i)

等价于 (f(A_{now}, j) = sum_i sum_{kle j} {jchoose k} f(i, k) A_k > i)

这个 (f) 可以用线段树简单维护一下,支持区间乘二,区间求和,区间赋值即可。

时间复杂度 (mathcal O (nm^2 + nmlog n))

#include <bits/stdc++.h>
#define forn(i,s,t) for(register int i=(s); i<=(t); ++i)
#define form(i,s,t) for(register int i=(s); i>=(t); --i)
#define rep(i,s,t) for(register int i=(s); i<(t); ++i)
using namespace std;
const int N = 1e5 + 3, M = 22, Mod = 1e9 + 7;
struct Mint {
	int res;
	Mint() {}
	Mint(int a) : res(a) {}
	inline friend Mint operator + (Mint A, Mint B) {
		return Mint((A.res + B.res >= Mod) ? (A.res + B.res - Mod) : (A.res + B.res));
	}
	inline friend Mint operator - (Mint A, Mint B) {return A + Mint(Mod - B.res);}
	inline friend Mint operator * (Mint A, Mint B) {return Mint(1ll * A.res * B.res %Mod);}
	inline friend Mint& operator += (Mint& A, Mint B) {return A = A + B;}
	inline friend Mint& operator -= (Mint& A, Mint B) {return A = A - B;}
	inline friend Mint& operator *= (Mint& A, Mint B) {return A = A * B;}
	inline friend Mint q_pow(Mint p, int k = Mod - 2) {
		static Mint res; res = Mint(1);
		for(; k; k >>= 1, p *= p) (k & 1) && (res *= p, 0);
		return res;
	}
	inline friend Mint operator ~ (Mint A) {return q_pow(A);}
};
Mint fac[N], ifac[N];
inline void init(int n) {
	fac[0] = Mint(1);
	forn(i,1,n) fac[i] = fac[i - 1] * Mint(i);
	ifac[n] = ~fac[n];
	form(i,n - 1,0) ifac[i] = ifac[i + 1] * Mint(i + 1);
}
inline Mint C(int n, int r) {return fac[n] * ifac[r] * ifac[n - r];}
Mint pow2[N], F[N][M]; int m;
struct node {
	int tag; Mint cof[M];
};
struct SegTree {
	node val[N << 2];
	inline void up(int p) {forn(i,0,m) val[p].cof[i] = val[p << 1].cof[i] + val[p << 1 | 1].cof[i];}
	inline void opt(int p, int k) {
		val[p].tag += k; forn(i,0,m) val[p].cof[i] *= pow2[k];
	}
	inline void down(int p) {
		opt(p << 1, val[p].tag), opt(p << 1 | 1, val[p].tag), val[p].tag = 0;
	}
	void Upd(int p, int l, int r, int nl, int nr) {
		if(l == nl && nr == r) return opt(p, 1);
		int mid = nl+nr >> 1;
		(val[p].tag) && (down(p), 0);
		if(r <= mid) Upd(p << 1, l, r, nl, mid);
		else if(l > mid) Upd(p << 1 | 1, l, r, mid + 1, nr);
		else Upd(p << 1, l, mid, nl, mid), Upd(p << 1 | 1, mid + 1, r, mid + 1, nr);
		up(p);
	}
	node Qry(int p, int l, int r, int nl, int nr) {
		if(l == nl && nr == r) return val[p];
		int mid = nl+nr >> 1;
		(val[p].tag) && (down(p), 0);
		if(r <= mid) return Qry(p << 1, l, r, nl, mid);
		else if(l > mid) return Qry(p << 1 | 1, l, r, mid + 1, nr);
		else {
			node res;
			node L = Qry(p << 1, l, mid, nl, mid);
			node R = Qry(p << 1 | 1, mid + 1, r, mid + 1, nr);
			forn(i,0,m) res.cof[i] = L.cof[i] + R.cof[i];
			return res;
		}
	}
	void Cov(int p, int l, int r, int pos) {
		if(l == r) {
			forn(i,0,m) val[p].cof[i] = F[pos][i];
			return ;
		}
		int mid = l+r >> 1;
		(val[p].tag) && (down(p), 0);
		if(pos <= mid) Cov(p << 1, l, mid, pos);
		else Cov(p << 1 | 1, mid + 1, r, pos);
		up(p);
	}
}T;
int n, a[N];
int main() {
	scanf("%d%d", &n, &m), init(m);
	pow2[0] = Mint(1);
	forn(i,1,n) scanf("%d", a + i), pow2[i] = pow2[i - 1] + pow2[i - 1];
	forn(i,1,n) {
		static node res;
		T.Upd(1, a[i], n, 1, n), res = T.Qry(1, 1, a[i], 1, n);
		forn(j,0,m) {
			F[a[i]][j] = Mint(1);
			forn(k,0,j) F[a[i]][j] += C(j, k) * res.cof[k];
		}
		T.Cov(1, 1, n, a[i]);
	}
	printf("%d", T.val[1].cof[m]);
	return 0;
}
原文地址:https://www.cnblogs.com/Ax-Dea/p/15026940.html