codeforces 665E Beautiful Subarrays

字典树。

#include<cstdio>
#include<cstring>
#include<cmath>
#include<vector>
#include<map>
#include<stack>
#include<queue>
#include<string>
#include<algorithm>
using namespace std;

const int maxn=1000000+10;
const int maxm=40;
struct Node
{
    int f;
    int son[2];
    int num;
}node[20*maxn];
int tot;

int n,a[maxn],k;
long long ans;
int r[maxm];


int add(int x)
{
    node[tot].f=x;
    node[tot].num=0;
    node[tot].son[0]=-1;
    node[tot].son[1]=-1;
    tot++;
    return tot-1;
}

long long get(int x)
{
    int h=0,base[maxm]; memset(base,0,sizeof base);
    while(x) base[h++]=x%2,x=x/2;

    long long res=0;

    int p=0;
    for(int i=maxm-5;i>=0;i--)
    {
        if(r[i]==0)
        {
            if(base[i]==0)
            {
                if(node[p].son[1]!=-1)
                    res=res+(long long)node[node[p].son[1]].num;
                p=node[p].son[0];

                if(p!=-1&&i==0) res=res+node[p].num;
                if(p==-1) break;
            }
            else
            {
                if(node[p].son[0]!=-1)
                    res=res+(long long)node[node[p].son[0]].num;
                p=node[p].son[1];
                if(p!=-1&&i==0) res=res+node[p].num;
                if(p==-1) break;
            }
        }
        else if(r[i]==1)
        {
            if(base[i]==0)
            {
                p=node[p].son[1];
                if(p!=-1&&i==0) res=res+node[p].num;
                if(p==-1) break;
            }
            else
            {
                p=node[p].son[0];
                if(p!=-1&&i==0) res=res+node[p].num;
                if(p==-1) break;
            }
        }
    }

    return res;
}

void update(int x)
{
    int h=0,base[maxm]; memset(base,0,sizeof base);
    while(x) base[h++]=x%2,x=x/2;

    int p=0;
    for(int i=maxm-5;i>=0;i--)
    {
        if(node[p].son[base[i]]==-1)
            node[p].son[base[i]]=add(base[i]);
        p=node[p].son[base[i]];
        node[p].num++;
    }
}

void init()
{
    ans=0;
    tot=0;
    add(-1);// root 0
    for(int i=1;i<=maxm-4;i++) node[i-1].son[0]=add(0);
    for(int i=1;i<tot;i++) node[i].num=1;
    memset(r,0,sizeof r);
    int v=0; while(k) r[v++]=k%2,k=k/2;
}

int main()
{
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);

    init();

    int c=0;
    for(int i=1;i<=n;i++)
    {
        c=c^a[i];

        ans=ans+get(c);
        update(c);
    }

    printf("%lld
",ans);

    return 0;
}
原文地址:https://www.cnblogs.com/zufezzt/p/5630759.html