hdu6059
题意
给定数组 (A) ,问有多少对下标 ((i, j, k)) 满足 (i < j < k) 且 ((A[i] xor A[j]) < (A[j] xor A[k])) 。
分析
首先建一棵字典树,从高到低位插入所有数字(长度要相同,所以前面不足用 (0) 补),在插入的过程中计算对于每个 (k) 前面有多少个 (j) 可以配对(也就是在前面插入的值中寻找),只要将当前位取反就能找到有多少个 (j) 与之对应(可以用一个 (cnt) 数组记录每一位分别为 (0) 和 (1) 的次数)。
查询时,从头开始删除,每删除一次(一是要去标记一下,而是去掉这个数作为 (k) 的影响),再去查询对应的数,我们求的实际是对于每个 (A[i]) 有几个 (A[j] A[k]) 与之对应 。在去计算答案的时候前面的标记就有作用了,已经标记作为 (j) 的与后面的 (k) 产生的配对要减掉,因为要满足 (i < j) ,前面标记过的 (j < i) ,所以前面的 (j) 与 (k) 的配对是无效的。
大致的意思就是对 (i) 去寻找 (k) ,然后删掉不满足条件的 (j) 。
建议结合代码画图理解一下。
code
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
const int MAXN = 2e6 + 10;
int n;
int a[MAXN];
int root, L;
int nxt[MAXN][2], cnt[MAXN][2], has[MAXN];
ll sum[MAXN];
ll ans;
int newnode() {
nxt[L][0] = nxt[L][1] = 0;
return L++;
}
void init() {
L = 1;
root = newnode();
memset(sum, 0, sizeof sum);
memset(cnt, 0, sizeof cnt);
memset(has, 0, sizeof has);
}
void insert(int x, int k) {
int tp[32], c = 0;
memset(tp, 0, sizeof tp);
while(x) {
tp[c++] = x % 2;
x >>= 1;
}
int now = root;
for(int i = 30; i >= 0; i--) {
int d = tp[i];
if(!nxt[now][d]) nxt[now][d] = newnode();
now = nxt[now][d];
cnt[i][d]++;
sum[now] += k * cnt[i][d ^ 1];
has[now] += k;
}
}
void query(int x) {
int tp[32], c = 0;
memset(tp, 0, sizeof tp);
while(x) {
tp[c++] = x % 2;
x >>= 1;
}
int now = root;
for(int i = 30; i >= 0; i--) {
int d = tp[i];
int tmp = nxt[now][d ^ 1];
if(tmp) {
ans += sum[tmp] - 1LL * has[tmp] * cnt[i][d];
}
now = nxt[now][d];
if(!now) break;
}
}
int main() {
int T;
scanf("%d", &T);
while(T--) {
init();
scanf("%d", &n);
for(int i = 0; i < n; i++) {
scanf("%d", &a[i]);
insert(a[i], 1);
}
ans = 0;
memset(cnt, 0, sizeof cnt);
for(int i = 0; i < n; i++) {
insert(a[i], -1);
query(a[i]);
}
printf("%lld
", ans);
}
return 0;
}