「51Nod 1601」完全图的最小生成树计数 「Trie」

题意

给定(n)个带权点,第(i)个点的权值为(w_i),任意两点间都有边,边权为两端点权的异或值,求最小生成树边权和,以及方案数(mod 10^9 + 7)

(n leq 10^5,W = max(w_i) leq 2^{30})

题解

考虑按位贪心,我们从高到低考虑二进制第k位。每次把当前点集(S)分成第(k)位为(0)和第(k)位为(1)的两个集合,记为(S_0, S_1)

我们递归下去把这两个集合连成生成树,然后再找一条最小的跨集合的边把这两个集合连通。

考虑这么做为啥对:假设有两条跨集合的边,我删去一条,树变成两个部分。然后任意找到一条集合内部边使集合(S)连通(既然有跨集合的边存在,我们一定能找到这样的一条边),这样显然更优。

然后考虑问题:找到(xin S_0,yin S_1,x ext{ xor } y)最小。

这个用类似线段树合并的方法:每次两个结点同时往下走,尽量往一边走。如果能同时往(0/1)走,都走一遍,复杂度是对的,每次合并复杂度是子树大小。考虑trie树上一个点只有(O(log W))个祖先,一共只有(O(n log W))个结点,所以复杂度(O(n log ^2 W))

我们再来考虑方案。叶子结点时假设大小为(n),也就是说(n)个点都是这个权值,生成树的方案数(n^{n-2})(由prufer序列得)。非叶子结点时,方案是分成的两个集合的方案乘最后连边方案。连边会对应trie树上多对叶子((u, v))(这些对结点异或起来都是最小的),若叶子(u)上放的数个数用(cnt[u])表示,连边方案就是(sum_{(u,v)} cnt[u]*cnt[v])

P.S.:快速幂写错了调了好久,差评

#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long ll;
char gc() {
	static char buf[1 << 20], * S, * T;
	if(S == T) {
		T = (S = buf) + fread(buf, 1, 1 << 20, stdin);
		if(S == T) return EOF;
	}
	return *S ++;
}
template<typename T> void read(T &x) {
	x = 0; char c = gc(); bool bo = 0;
	for(; c > '9' || c < '0'; c = gc()) bo |= c == '-';
	for(; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15);
	if(bo) x = -x;
}
const int N = 1e5 + 10;
const int mo = 1e9 + 7;
int n, id = 1, ch[N * 30][2], cnt[N * 30], w[N * 30];
void insert(int x) {
	int u = 1;
	for(int i = 29; ~ i; i --) {
		int y = x >> i & 1;
		if(!ch[u][y]) {
			ch[u][y] = ++ id;
			w[id] = y << i;
		}
		u = ch[u][y];
	}
	cnt[u] ++;
}
ll ans;
int ans2, tot2, tot = 1;
int qpow(int a, int b) {
	int ans = 1;
	for(; b >= 1; b >>= 1, a = (ll) a * a % mo)
		if(b & 1) ans = (ll) ans * a % mo;
	return ans;
}
void merge(int u, int v, int now) {
	now ^= w[u] ^ w[v];
	if(cnt[u] && cnt[v]) {
		if(now < ans2) { ans2 = now; tot2 = 0; }
		if(now == ans2) tot2 = (tot2 + (ll) cnt[u] * cnt[v]) % mo;
		return ;
	}
	bool tag = 0;
	if(ch[u][0] && ch[v][0]) merge(ch[u][0], ch[v][0], now), tag = 1;
	if(ch[u][1] && ch[v][1]) merge(ch[u][1], ch[v][1], now), tag = 1;
	if(tag) return ;
	if(ch[u][0] && ch[v][1]) merge(ch[u][0], ch[v][1], now);
	if(ch[u][1] && ch[v][0]) merge(ch[u][1], ch[v][0], now);
}
bool solve(int u) {
	if(!u) return 0;
	if(cnt[u]) {
		if(cnt[u] > 2) tot = (ll) tot * qpow(cnt[u], cnt[u] - 2) % mo;
		return 1;
	}
	bool s = solve(ch[u][1]) & solve(ch[u][0]);
	if(s) {
		ans2 = 2e9 + 10; tot2 = 1;
		merge(ch[u][0], ch[u][1], 0);
		ans += ans2; tot = (ll) tot * tot2 % mo;
	}
	return 1;
}
int main() {
	read(n);
	for(int i = 1; i <= n; i ++) {
		int x; read(x); insert(x);
	}
	solve(1);
	printf("%lld
%d
", ans, tot);
	return 0;
}
原文地址:https://www.cnblogs.com/hongzy/p/11655899.html