题解「CF241B Friends」

链接

CodeForces 241B

题意

给定一个长度为 (N) 的序列,求前 (k) 大的 (a_i operatorname{xor} a_j) 之和,答案对 (10^9+7) 取模。

其中 (a_i operatorname{xor} a_j)(a_j operatorname{xor} a_i) 被看成是同一对。

( exttt{Data Range:}1leq nleq 10^6,1leq kleq frac{n(n-1)}{2})

约定

以下称 (a_i operatorname{xor} a_j) 的值为异或值。

(b_v[i..j]) 为值 (v)二进制意义 下第 (i) 位至第 (j) 位的取值,如 (b_3[0..1]=11_2)(b_5[1..2]=10_2)(b_5[0,2]=101_2)(下标 (2) 即表示二进制意义下)。

二进制意义下的最高位为 (mx)

题解

此题是 [十二省联考2019]异或粽子 的加强版,解法吊打 ( exttt{std})

提供了一种时间复杂度与 (k) 无关的解法,时间复杂度为 (O(nomega^2+nlog n)),其中 (omega) 为二进制意义下 (a_i) 的最高位,是一个(leq 31) 的常数

首先可以想到,要求出前 (k) 大的异或和,必须先求出第 (k) 大的异或和。

有一种很暴力的做法,枚举每个 (a_i,a_j),求出第 (k) 大的异或值。无疑可以优化,枚举每一个 (a_i),将上述过程放至 ( exttt{trie}) 上,从高位贪心,用类似权值线段树求第 (k) 大的方式在 ( exttt{trie}) 上二分即可。

更加具体地说,我们将一棵 ( exttt{trie}) 分为 (n) 层,第 (i) 层的指针即 (a_i) 所对应的 (a_j) 的前缀。从最高位开始枚举第 (k) 大的 (a_i operatorname{xor} a_j) 的每一位上取 (0/1),在枚举到第 (x) 位时,在每一层上计算 (a_i operatorname{xor} a_j) 的第 (x) 位为 (1) 的方案数,记这个数为 (cnt)。若 (cntgeq k),则说明这一位必须取 (1),否则所求会小于第 (k) 大。若 (cnt < k),则说明这一位必须取 (0) ,否则所求会大于第 (k) 大。在 (n)(trie) 上对应更新指针即可,并统计第 (x) 位的取值,更新第 (k) 大的异或值在 (x) 上的取值即可。

求出第 (k) 大的 (a_i operatorname{xor} a_j) 后,记其为 (val)。若 (val)(x) 位上为 (0),则若存在一个异或值 (t),令(b_{val}[x+1..mx]=b_t[x+1..mx]),仅在 (t) 的第 (x) 位上为 (1)(t) 一定大于 (val) ,属于前 (k) 大的取值。

考虑计算这部分值的贡献。仍然是从高位贪心,分成 (n) 层。每层按照 (val) 的值在 ( exttt{trie}) 上递归,枚举值 (val) 在二进制意义上为 (0) 的那些位,统计仅在这一位上取 (1) ,第 (x+1-mx) 位上与 (val) 在二进制意义下取值相同的那些异或值对答案造成的贡献。

由于即使确定了 (a_i)(a_j) 也是在一棵子树中,对一棵子树统计答案,不太好做。可以在初始时直接对 (a_i) 排序,这样 ( exttt{trie}) 的一棵子树即对应了 (a) 上一段连续的区间。统计贡献时按位统计,对每一位上的 (0/1) 分别维护一个前缀和,这样就可以了。

Show the Code

#include<cstdio>
#include<algorithm>
typedef long long ll;
const ll mod=1e9+7;
int n,tot=0,rt=0;
int a[50005],tmp[50005];
int sum[35][50005][2];
int ch[1000005][2],size[1000005],l[1000005],r[1000005];
inline ll read() {
	register ll x=0,f=1;register char s=getchar();
	while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
	while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
	return x*f;
}
inline ll Add(ll x,ll y) {return ((x+y)%mod+mod)%mod;}
inline ll pow(ll x,ll p) {
	ll res=1;
	for(;p;p>>=1) {if(p&1) res=res*x%mod;x=x*x%mod;}
	return res;
}
inline void insert(int val,int id) {
	if(!rt) rt=++tot,l[rt]=1,r[rt]=n;
	int p=rt;++size[rt];
	for(register int d=31;d>=0;--d) {
		int &nxt=ch[p][val>>d&1];
		if(!nxt) {nxt=++tot;l[nxt]=id;} 
		p=nxt;r[p]=id;++size[p];
	}
}
inline int getVal(ll &k) {
	int res=0;
	for(register int i=1;i<=n;++i) tmp[i]=rt;
	for(register int d=31;d>=0;--d) {//若当前取1的个数为cnt,cnt>k 
		int x=0;
		ll cnt=0;
		for(register int i=1;i<=n;++i) cnt+=size[ch[tmp[i]][(a[i]>>d&1)^1]];
		if(cnt>=k) res|=(1<<d),x=1;//cnt<k
		else k-=cnt,x=0;
		for(register int i=1;i<=n;++i) tmp[i]=ch[tmp[i]][(a[i]>>d&1)^x];
	}
	return res;
}
inline ll ask(int l,int r,int val) {
	if(!l||!r) return 0;
	ll res=0;
	for(register int d=31;d>=0;--d) res=Add(res,(ll)(sum[d][r][(val>>d&1)^1]-sum[d][l-1][(val>>d&1)^1])*((1ll<<d)%mod)%mod);
	return res;
}
inline ll getSum(ll k) {
	int val=getVal(k);
	ll res=val*k%mod;
	for(register int i=1;i<=n;++i) tmp[i]=rt;
	for(register int d=31;d>=0;--d) {
		int x=val>>d&1;
		if(!x) {for(register int i=1;i<=n;++i) res=Add(res,ask(l[ch[tmp[i]][(a[i]>>d&1)^1]],r[ch[tmp[i]][(a[i]>>d&1)^1]],a[i]));}
		for(register int i=1;i<=n;++i) tmp[i]=ch[tmp[i]][(a[i]>>d&1)^x];
	}
	return res;	
}
signed main() {
	n=read();
	ll k=read()*2ll;//25*1e8=2.5*1e9
	for(register int i=1;i<=n;++i) a[i]=read();
	std::sort(a+1,a+1+n);
	for(register int i=1;i<=n;++i) insert(a[i],i);
	for(register int i=1;i<=n;++i) {
		for(register int d=31;d>=0;--d) {
			sum[d][i][0]=sum[d][i-1][0];
			sum[d][i][1]=sum[d][i-1][1];
			++sum[d][i][a[i]>>d&1]; 
		}
	} 
	printf("%lld
",getSum(k)*pow(2,mod-2)%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/tommy0103/p/13050433.html