2017 Multi-University Training Contest

题解:

官方题解太简略了orz

具体实现的方式其实有很多

问题就在于确定A[j]以后,如何找符合条件的A[i]

这里其实就是要提前预处理好

我是倒序插入点的,所以要沿着A[k]爬树,找符合的A[i]

如果发现A[i]与A[k]的第p位不同,比如A[k]位1,A[i]为0,那么所有的在i右边的第p位为0的数就都可以充当A[j]

所以实际上就需要求出有多少点对(i, j),满足这个条件。

不妨用可持久化的思想考虑这个过程

倒序插入A[i]时,我们就能统计出来A[i]的第p位为0(或者为1)时,所有在i右边的第p位为0(或者为1)的数有多少个

但是,问题在于我们需要删除结点

这个过程就要倒着想

如果删除A[i]

1、对于删除的那条字典树的链,链上每个点减少的贡献为 “那个结点的子树大小”

2、对于非链上的点,如果这个点和A[i]相应的第p位相同,那么它减少的贡献也是“这个结点的子树大小”

但注意,1情况对应的子树大小实际上是要减1的,因为被删除了一个结点。

我们用一个数组记录第p位为0或1时删除了几次,就可以处理第二种情况

但是第一种情况是比较特殊的,所以我们对每个结点都记录一下它上次被删除是哪一次

这样就可以做了

#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <queue>
#include <cstdlib>
using namespace std;
const int maxn = 5e5 + 200;
typedef long long LL;
struct Node{
    Node* ch[2];
    LL num, ans, Mv;
}pool[maxn*31], *null;
int tot, a[maxn], tt = 0;
LL Minus[32][2], Plus[32][2];
inline Node* newnode(){
    Node* x = &pool[tot++];
    x->ch[0] = x->ch[1] = null;
    x->num = x->ans = x->Mv = 0;
    return x;
}
void pre(){
    null = newnode();
    null->ch[0] = null->ch[1] = null;
    null->num = 0;
}
inline void Insert(Node* root, int x){
    Node* u = root;
    for(int i = 30; i >= 0; i--){
        int c = (x&(1<<i)) ? 1 : 0;
        if(u->ch[c] == null){
            u->ch[c] = newnode();
        }
        u->num++;
        u->ch[c]->ans += Plus[i][c];
        Plus[i][c]++;
        u = u->ch[c];
    }
    u->num++;
}

inline void Erase(Node* root, int x){
    Node* u = root;
    for(int i = 30; i >= 0; i--){
        int c = (x&(1<<i)) ? 1 : 0;
        u->num--;
        Minus[i][c]++;
        u->ch[c]->ans -= (Minus[i][c] - u->ch[c]->Mv - 1)*u->ch[c]->num;
        u->ch[c]->ans -= u->ch[c]->num - 1;
        u->ch[c]->Mv = Minus[i][c];
        u = u->ch[c];
    }
    u->num--;
}

inline LL Find(Node* root, int x){
    LL ans = 0;
    Node* u = root;
    for(int i = 30; i >= 0; i--){
        int c = (x&(1<<i)) ? 1 : 0;
        LL v = u->ch[c^1]->Mv;
        LL rnum = u->ch[c^1]->ans - (Minus[i][c^1]-v)*u->ch[c^1]->num;
        ans += rnum;
        u = u->ch[c];
    }
    return ans;
}

int main()
{
    int T; cin>>T;
    for(; T; T--){
        int n; cin>>n;
        LL ans = 0;
        tot = 0; pre();
        Node* root = newnode();
        memset(Minus, 0, sizeof(Minus));
        memset(Plus, 0, sizeof(Plus));
        for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
        for(int i = n-1; i >= 1; i--) Insert(root, a[i]);
        for(int i = n; i >= 3; i--){
            ans += Find(root, a[i]);
            Erase(root, a[i-1]);
        }
        cout<<ans<<endl;
    }
}
原文地址:https://www.cnblogs.com/Saurus/p/7290186.html