hdu 6059---Kanade's trio(字典树)

题目链接

Problem Description
Give you an array A[1..n],you need to calculate how many tuples (i,j,k) satisfy that (i<j<k) and ((A[i] xor A[j])<(A[j] xor A[k]))

There are T test cases.

1T20

1n5105

0A[i]<230
 
Input
There is only one integer T on first line.

For each test case , the first line consists of one integer n ,and the second line consists of n integers which means the array A[1..n]
 
Output
For each test case , output an integer , which means the answer.
 
Sample Input
1
5
1 2 3 4 5
 
Sample Output
6

题意:输入一个数列a[1]~a[n] ,求有多少个三元组(i,j,k) 满足1<=i<j<k<=n  且  a[i]异或a[j] < a[j]异或a[k]?

思路:对于a[i]与a[k],对于二进制从高位向低位进行判断,如果30位(A[i]<2^30)到25位相同,那么a[j]的这些位不管值是多少不影响异或后 a[i] 与 a[k] 的大小关系,现在第24位不同,那么a[j]的这一位必须和a[i]相同,这样a[k]异或a[j]的值一定大于a[i]异或a[j] ,第23位到第0位不管a[j]取何值不会影响大小关系了。  有上述可以得出我们只需要判断a[i]和a[k]的二进制最高不相同位就行,那么可以用一个二进制的字典树存储这n个数。

       从a[i]~a[n]将a[k]插入字典树中,每次插入时需要记录 当前节点有多少数(num表示)、当前节点对应的a[j]有多少(count表示),用cn[32][2]记录第i位为0和1时的a[j]的个数,所以每次到一个节点时用count+=cn[i][1-t],表示当前的位(0或1),这样可以保证j<k,但是没有保证i<j ;

       接下来将cn[][]清空,从a[1]~a[n]的进行删除,对于a[i]删除,可以保证i<k ,那么可以用count-num*cn[i][t] 保证i<j ;

代码如下:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long LL;
const int N=5e5+5;
int a[N],p[35],cn[32][2];
struct node
{
    node *son[2];
    int count;
    int num;
    node() { count=0; num=0; son[0]=son[1]=NULL; }
};
node *root;

void add(int x,int v)
{
   node * now=root;
   for(int i=30;i>=0;i--)
   {
       int t=(!!(p[i]&x));
       if(now->son[t]==NULL)  now->son[t]=new node();
       now=now->son[t];
       now->num+=v;
       cn[i][t]++;
       now->count+=v*cn[i][1-t];///当前点对应的j的个数;
   }
}

LL cal(int x)
{
    node * now=root;
    LL sum=0;
    for(int i=30;i>=0;i--)
    {
       int t=(!!(p[i]&x));
       node* bro=now->son[1-t];
       if(bro)
       sum+=bro->count - ((LL)bro->num*(LL)cn[i][t]);
       now=now->son[t];
       if(!now) break;
    }
    return sum;
}

int main()
{
    ///cout << "Hello world!" << endl;
    int T;  cin>>T;
    p[0]=1;
    for(int i=1;i<32;i++) p[i]=p[i-1]<<1;
    while(T--)
    {
       int n;  scanf("%d",&n);
       for(int i=1;i<=n;i++)  scanf("%d",&a[i]);
       root=new node();
       memset(cn,0,sizeof(cn));
       for(int i=1;i<=n;i++)  add(a[i],1);
       memset(cn,0,sizeof(cn));
       LL ans=0;
       for(int i=1;i<n;i++){
          add(a[i],-1);
          ans+=cal(a[i]);
       }
       printf("%lld
",ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/chen9510/p/7289433.html