HDU 6059

题意略。

思路:我们要想令 A[i] ^ A[j] < A[j] ^ A[k](i < j < k),由于A[i]和A[k]都要 ^ A[j],所以我们只需研究一下i,k这两个数之间的关系即可。

我们按位来考虑这两个数之间的关系,可以想到,A[i]和A[k]这两个数的最高不相同位决定了A[i] ^ A[j]与A[j] ^ A[k]的大小关系:

(下面用high[ ]来表示A[i]和A[k]这两个数的最高不相同位)

1.如果high[i] = 1,high[k] = 0,那么A[j]的这一位应该是1。

2.如果high[i] = 0,high[k] = 1,那么A[j]的这一位应该是0。

那么,对于每一个A[k],我们枚举它与前面数的最高不相同位来计算它对最后答案的贡献。

现在看看怎么来达到这个目的:

1.我们需要知道A[i]的个数c1,对A[i]的约束即是在最高不相同位的更高位上与A[k]相同,在最高不相同位上与A[k]相异,这个个数我们可以用字典树来维护。

2.我们需要知道A[j]的个数c2,A[j]需要满足的条件是在最高不同位上与A[k]相异,这个我们可以用一个二维数组Cnt[31][2]来维护,

里面记录着从A[1]~A[k - 1]这k - 1个数中在第i位为j的项的个数。

那A[k]的贡献是不是就是c1 * c2了呢?并不是。

有两个不合理的条件:

1.c2中包含了c1,也就是说c1 * c2中有可能有同一个数选了两次的情况。

2.c1 * c2只保证了i < k && j < k,未能保证i < j这个条件。

为了去掉1中的不合理因素,我们只需要减去c1即可。

为了去掉2中的不合理因素,我们可以用illegal[ ]来记录字典树上这个结点的不合理数。怎么记录呢?

每当我们插入这个结点的时候,当前Cnt[ ][ ]数组中存的值都是在当前结点之前出现过的,它们都是当前结点的不合理数。

所以,A[k]的贡献是c1 * c2 - c1 - illegal[cur_node]。

详见代码:

#include<bits/stdc++.h>
#define maxn 500000 * 31
using namespace std;
typedef long long LL;

int Cnt[31][2];

struct Trie{
    int relation[maxn][2],info[maxn];
    LL illegal[maxn];
    
    int root,cnt;
    int newnode(){
        relation[cnt][0] = relation[cnt][1] = -1;
        info[cnt] = illegal[cnt] = 0;
        return cnt++;
    }
    void init(){
        cnt = 0;
        root = newnode();
    }
    void insert(int x){
        int cur = root;
        ++info[cur];
        for(int i = 29;i >= 0;--i){
            if(relation[cur][(x>>i) & 1] == -1)
                relation[cur][(x>>i) & 1] = newnode();
            cur = relation[cur][(x>>i) & 1];
            illegal[cur] += (Cnt[i][(x>>i) & 1]);
            ++Cnt[i][(x>>i) & 1];
            ++info[cur];
        }
    }
    LL query(int x){
        LL ret = 0;
        int cur = root;
        for(int i = 29;i >= 0;--i){
            int numb = ((x>>i) & 1),another = 1 - numb;
            int idx = relation[cur][another];
            LL c = info[idx];
            LL temp = c * (Cnt[i][another]) - c - illegal[idx];
            ret += temp;
            cur = relation[cur][numb];
            if(cur == -1) break;
        }
        return ret;
    }
};

Trie trie;

int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        LL ans = 0;
        int n;
        scanf("%d",&n);
        trie.init();
        memset(Cnt,0,sizeof(Cnt));
        for(int i = 0;i < n;++i){
            int temp;
            scanf("%d",&temp);
            ans += trie.query(temp);
            trie.insert(temp);
        }
        printf("%lld
",ans);
    }
    return 0;
}

/*
1
4
9 8 7 3
*/
原文地址:https://www.cnblogs.com/tiberius/p/8727939.html