[做题记录-计数相关] [AGC023E] Inversions

题意

给定长度为 $n$ 的上限数组 $a$，求所有满足 $p_i \le a_i$（对每个 $i$）的 $1 \sim n$ 的排列 $p$ 的逆序对个数之和（代码中对 $10^9 + 7$ 取模）。

$n \leq 10^5$

题解

经典套路：先求出合法排列的总数，再考虑每一对位置对答案的贡献。

考虑从大往小填数，记 $b_i$ 表示满足 $a_j \ge i$ 的位置 $j$ 的个数。填数 $i$ 时，这 $b_i$ 个位置中已有 $n - i$ 个被更大的数占用（更大的数能放的位置必然也满足 $a_j \ge i$），因此合法的排列个数为：

$$cnt = \prod_{i=1}^{n} \bigl(b_i - (n - i)\bigr)$$

然后对序列上的一对位置考虑其贡献。不妨设 $i < j$，记 $p_i$ 表示填好以后位置 $i$ 上的数。

$a_i = a_j$ 时，交换 $p_i$ 与 $p_j$ 后排列仍然合法，因此 $p_i > p_j$ 与 $p_i < p_j$ 的情况数相同，贡献是 $\frac{cnt}{2}$。

$a_i < a_j$ 时，讨论 $p_j$ 的取值。当 $p_j \in [a_i + 1, a_j]$ 时必有 $p_j > p_i$，没有贡献；因此只需统计 $p_j \le a_i$ 的情况，这等价于强行令 $a_j = a_i$。这样对于所有 $k \in [a_i + 1, a_j]$，$b_k$ 都会减少 $1$，而此时两个位置的上限相同，逆序与顺序各占一半。于是这里的贡献是：

$$\frac{cnt}{2} \prod_{k = a_i + 1}^{a_j} \frac{b_k - (n - k) - 1}{b_k - (n - k)}$$

$a_i > a_j$ 时，把 $i, j$ 的角色反过来，就变成了上限较小的位置 $j$ 与上限较大的位置 $i$ 的问题。此时顺序对（$p_i < p_j$）的方案数与上一种情况中逆序对的方案数计算方式完全相同，用总方案数减去出现顺序对的方案数即可：

$$cnt - \frac{cnt}{2} \prod_{k = a_j + 1}^{a_i} \frac{b_k - (n - k) - 1}{b_k - (n - k)}$$

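然后考虑快速计算。以 $a_i < a_j$ 为例，从小往大枚举 $a_j$ 的值 $k$，维护一个全局的数据结构（代码中是支持区间乘的线段树）：每当 $a_j$ 增大到 $k$ 时全局乘上 $\frac{b_k - (n - k) - 1}{b_k - (n - k)}$；查询时求位置小于 $j$ 的所有位置的权值和；然后在位置 $j$ 加入权值 $\frac{cnt}{2}$。$a_i > a_j$ 的情况同理：其中的 $cnt$ 部分用树状数组统计满足条件的数对个数，乘积部分再扫描一遍，改为查询位置大于 $j$ 的权值和。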

/*
	QiuQiu /qq
  ____    _           _                 __                
  / __   (_)         | |               / /                
 | |  | |  _   _   _  | |  _   _       / /    __ _    __ _ 
 | |  | | | | | | | | | | | | | |     / /    / _` |  / _` |
 | |__| | | | | |_| | | | | |_| |    / /    | (_| | | (_| |
  \___\_ |_|  \__,_| |_|  \__, |   /_/      \__, |  \__, |
                            __/ |               | |     | |
                           |___/                |_|     |_|
*/

#include <bits/stdc++.h>

using namespace std;

// Buffered reader over stdin (fread in MX-byte chunks) with operator>> for integers,
// single characters and whitespace-delimited C strings.
class Input {
	private :
		static const int MX = 1000000;  // buffer size in bytes
		char buf[MX], *p1 = buf, *p2 = buf;  // [p1, p2) is the unread part of buf
		// Next byte of stdin as an unsigned char value, or EOF once input is exhausted.
		// Returning int (not char) keeps EOF distinct from byte 0xFF and keeps the value
		// inside the domain isdigit/isspace accept.
		inline int gc() {
			if (p1 == p2) p2 = (p1 = buf) + fread(buf, 1, MX, stdin);
			return p1 == p2 ? EOF : static_cast<unsigned char>(*(p1 ++));
		}
	public :
		Input() {
			#ifdef Open_File
				freopen("a.in", "r", stdin);
				freopen("a.out", "w", stdout);
			#endif
		}
		// Reads a decimal integer. Every non-digit byte before the first digit is skipped, and
		// a '-' anywhere in that skipped prefix makes the value negative. If input ends before
		// any digit appears, x is left as 0 instead of looping forever.
		template <typename T>
		inline Input& operator >>(T &x) {
			x = 0;
			bool neg = false;
			int c = gc();
			for (; c != EOF && ! isdigit(c); c = gc()) if (c == '-') neg = true;
			// Accumulate with the final sign so the most negative value does not overflow.
			for (; c != EOF && isdigit(c); c = gc())
				x = neg ? x * 10 - (c - '0') : x * 10 + (c - '0');
			return *this;
		}
		// Reads the next non-whitespace byte; at end of input ch becomes (char)EOF.
		inline Input& operator >>(char &ch) {
			int c = gc();
			while (c != EOF && isspace(c)) c = gc();
			ch = static_cast<char>(c);
			return *this;
		}
		// Reads one whitespace-delimited token into s and NUL-terminates it. Leading whitespace
		// is skipped, so CRLF line endings or repeated separators do not produce empty tokens.
		// The caller must provide a buffer large enough for the token.
		inline Input& operator >>(char *s) {
			int c = gc();
			while (c != EOF && isspace(c)) c = gc();
			int p = 0;
			for (; c != EOF && ! isspace(c); c = gc()) s[p ++] = static_cast<char>(c);
			s[p] = '\0';
			return *this;
		}
} Fin;

// Buffered writer to stdout: bytes collect in an MX-byte buffer that is written with
// fwrite whenever it fills up and once more when the object is destroyed.
class Output {
	private :
		static const int MX = 1000000;  // buffer size in bytes
		char ouf[MX], *p1 = ouf, *p2 = ouf;  // p1: buffer start, p2: next free byte
		char Of[105], *o1 = Of, *o2 = Of;    // scratch stack for the digits of one number
		void flush() { fwrite(ouf, 1, p2 - p1, stdout); p2 = p1; }
		inline void pc(char ch) {
			*(p2 ++) = ch;
			if (p2 == p1 + MX) flush();
		}
	public :
		// Writes n in decimal. T may be any integer-like type supporting <, ==, != and %
		// against int plus /= int. Each digit is taken as |n % 10| and n itself is never
		// negated, so the most negative value of a signed type prints correctly instead
		// of overflowing.
		template <typename T>
		inline Output& operator << (T n) {
			if (n < 0) pc('-');
			if (n == 0) pc('0');
			while (n != 0) {
				int d = n % 10;
				if (d < 0) d = -d;
				*(o1 ++) = static_cast<char>('0' + d);
				n /= 10;
			}
			while (o1 != o2) pc(*(--o1));
			return *this;
		}
		inline Output & operator << (char ch) {
			pc(ch); return *this;
		}
		// Writes a NUL-terminated string.
		inline Output & operator <<(const char *ch) {
			for (const char *p = ch; *p != '\0'; ++ p) pc(*p);
			return *this;
		}
		~Output() { flush(); }
} Fout;

// Route the stream names used in main() to the buffered Fin / Fout objects above.
// Side effect: any later `std::cin` / `std::cout` spelling would expand to std::Fin / std::Fout.
#define cin Fin
#define cout Fout
// A plain newline character: unlike std::endl it does not force a flush after every line.
#define endl '\n'

using LL = long long;

// Bit helpers, defined at the bottom of the file.
inline int log2(unsigned int x);
inline int popcount(unsigned x);
inline int popcount(unsigned long long x);

// Integer modulo `mod`. Invariant: 0 <= v < mod, because every constructor and assignment
// reduces its input (negative ints included).
// `~x` is the modular inverse computed as x^(mod-2), which assumes `mod` is prime; ~0 yields 0.
// NOTE: `/=` is plain integer division of the stored representative, NOT modular division.
// Output::operator<< relies on this to peel off decimal digits.
template <int mod>
class Int {
	private :
		// Maps x from (-mod, mod) into [0, mod) without a branch (adds mod when x < 0).
		static inline int Mod(int x) { return x + ((x >> 31) & mod); }
		// Reduces any integer into [0, mod).
		static inline int norm(long long x) {
			x %= mod;
			return static_cast<int>(x < 0 ? x + mod : x);
		}
		// x^k mod `mod` by binary exponentiation; requires k >= 0.
		static inline int power(int x, long long k) {
			int res = 1;
			while (k) {
				if (k & 1) res = 1LL * x * res % mod;
				x = 1LL * x * x % mod; k >>= 1;
			}
			return res;
		}
	public :
		int v;
		Int(int _v = 0) : v(norm(_v)) {}
		operator int() const { return v; }

		inline Int& operator =(Int x) { v = x.v; return *this; }
		inline Int& operator =(int x) { v = norm(x); return *this; }
		inline Int operator *(Int x) const { return Int(static_cast<int>(1LL * v * x.v % mod)); }
		inline Int operator *(int x) const { return Int(static_cast<int>(1LL * v * x % mod)); }
		inline Int operator +(Int x) const { return Int(Mod(v + x.v - mod)); }
		// int operands are reduced first, so out-of-range or negative ints are handled.
		inline Int operator +(int x) const { return *this + Int(x); }
		inline Int operator -(Int x) const { return Int(Mod(v - x.v)); }
		inline Int operator -(int x) const { return *this - Int(x); }
		inline Int operator ~() const { return Int(power(v, mod - 2)); }
		inline Int& operator +=(Int x) { v = Mod(v + x.v - mod); return *this; }
		inline Int& operator +=(int x) { return *this += Int(x); }
		inline Int& operator -=(Int x) { v = Mod(v - x.v); return *this; }
		inline Int& operator -=(int x) { return *this -= Int(x); }
		inline Int& operator *=(Int x) { v = static_cast<int>(1LL * v * x.v % mod); return *this; }
		inline Int& operator *=(int x) { v = norm(1LL * v * x); return *this; }
		// Plain integer division of the representative (see the class comment).
		inline Int& operator /=(Int x) { v = v / x.v; return *this; }
		inline Int& operator /=(int x) { v = norm(v / x); return *this; }
		// x^k. A negative k raises the inverse to -k instead of looping forever.
		inline Int operator ^(int k) const {
			return k >= 0 ? Int(power(v, k))
			              : Int(power(power(v, mod - 2), -static_cast<long long>(k)));
		}
} ;

// Arithmetic modulo the prime 1e9 + 7.
using mint = Int<1000000007>;

// Array capacity; main() touches indices up to n + 1 (b[i + 1] with i == n).
const int N = 200000 + 10;
// Modular inverse of 2.
const mint inv2 = ~ mint(2);

// Problem size.
int n;
// a[i]: upper bound of position i.
int a[N];
// b[v]: number of positions j with a[j] >= v (built in main()).
int b[N];
// Number of valid permutations.
mint cnt;

// Segment tree node over the inclusive index range [l, r]. Supports multiplying every
// value by a constant (lazily, via tg), assigning one position, and summing a range.
struct Node {
	Node *ls, *rs;  // left / right child; both null for a leaf
	mint cj, tg;    // cj: sum of the values in [l, r]; tg: factor still owed to both children
	int l, r;       // covered index range, inclusive
	Node() {}
	Node(int _l, int _r) : ls(nullptr), rs(nullptr), cj(0), tg(1), l(_l), r(_r) {}
	// Recomputes this node's sum from its two children.
	void upd() { cj = ls->cj + rs->cj; }
	// Scales this whole segment by v, deferring the children's share to tg.
	void downcj(mint v) {
		cj *= v;
		tg *= v;
	}
	// Hands a pending factor (if any) down to both children.
	void pushdown() {
		if (tg == 1) return;
		ls->downcj(tg);
		rs->downcj(tg);
		tg = 1;
	}
	// Sets the value at index pos to v.
	void modify(int pos, mint v) {
		if (l == r) {
			cj = v;
			return;
		}
		pushdown();
		Node *child = pos <= ((l + r) >> 1) ? ls : rs;
		child->modify(pos, v);
		upd();
	}
	// Sum of the values at indices in [L, R]; callers pass a range inside the root's range.
	mint qry(int L, int R) {
		if (L <= l && r <= R) return cj;
		pushdown();
		const int mid = (l + r) >> 1;
		mint sum = 0;
		if (L <= mid) sum += ls->qry(L, R);
		if (mid < R) sum += rs->qry(L, R);
		return sum;
	}
	// Multiplies every value in [l, r] by v.
	void mul(mint v) { downcj(v); }
} ;

// Root of the segment tree used by main().
Node *root;

// Allocates a segment tree over [l, r]; every sum starts at 0 and every tag at 1.
Node *build(int l, int r) {
	Node *node = new Node(l, r);
	if (l != r) {
		const int mid = (l + r) >> 1;
		node->ls = build(l, mid);
		node->rs = build(mid + 1, r);
	}
	return node;
}

using pii = pair<int, int>;

// Fenwick tree (binary indexed tree) of counts over indices 1..n.
int c[N];
// Lowest set bit of x. The argument is fully parenthesized so that lowbit(i + 1) expands correctly.
#define lowbit(x) ((x) & -(x))
// Adds y at index x. Indices outside 1..n are ignored: x <= 0 would otherwise loop forever,
// since lowbit(0) == 0.
void upd(int x, int y) {
	if (x <= 0) return;
	for (; x <= n; x += lowbit(x)) c[x] += y;
}
// Returns the sum of the counts at indices 1..x; 0 for x <= 0, which previously walked out of bounds.
int qry(int x) {
	int ans = 0;
	for (; x > 0; x -= lowbit(x)) ans += c[x];
	return ans;
}

int main() {
	cin >> n;
	for(int i = 1; i <= n; i ++) cin >> a[i];
	for(int i = 1; i <= n; i ++) b[a[i]] ++;
	for(int i = n; i >= 1; i --) b[i] += b[i + 1];
	cnt = 1;
	for(int i = 1; i <= n; i ++) cnt = cnt * (b[i] - (n - i));
	root = build(1, n);
	mint ans = 0;
	static vector<int> lim[N];
	for(int i = 1; i <= n; i ++) lim[a[i]].push_back(i);
	for(int i = 1; i <= n; i ++) {
		mint value = b[i] - (n - i) - 1;
		value = value * (~ (value + 1));
		root -> mul(value);
		for(int j : lim[i]) ans += root -> qry(1, j);
		for(int j : lim[i]) root -> modify(j, cnt * inv2);
		mint t = lim[i].size();
		ans += t * (t - 1) * inv2 * cnt * inv2;
	}
	//cout << ans << endl;
	for(int i = n; i >= 1; i --) {
		ans += cnt * qry(a[i] - 1);
		upd(a[i], 1); 
	}
//	cerr << ans << endl;
	root = build(1, n);
	for(int i = 1; i <= n; i ++) {
		mint value = b[i] - (n - i) - 1;
		value = value * (~ (value + 1));
		root -> mul(value);
		for(int j : lim[i]) ans -= root -> qry(j, n);
		for(int j : lim[i]) root -> modify(j, cnt * inv2);
		//mint t = lim[i].size();
		//ans -= t * (t - 1) * inv2 * cnt * inv2;
	}
	cout << ans << endl;
	return 0;
}

// floor(log2(x)), i.e. the index of the highest set bit, for x > 0; returns -1 for x == 0.
// __builtin_clz is undefined for 0, hence the guard. The previous body returned
// __builtin_ffs(x), the 1-based index of the LOWEST set bit.
inline int log2(unsigned int x) {
	return x ? static_cast<int>(sizeof(unsigned int) * 8) - 1 - __builtin_clz(x) : -1;
}
// Number of set bits in x.
inline int popcount(unsigned int x) { return __builtin_popcount(x); }
// Number of set bits in x. Uses the `ll` builtin: `long` is only 32 bits on some ABIs (e.g. Windows).
inline int popcount(unsigned long long x) { return __builtin_popcountll(x); }
原文地址:https://www.cnblogs.com/clover4/p/15304569.html