[JOI2018] Snake Escaping

首先暴力是考虑对每个询问枚举问号填 1 还是 0,然后把每种可能统计到答案中,复杂度是 (O(2^n imes q)) 的。

然后我们可以观察出一些性质:

比如 1?1,他可能的情况是 111 或者 101,他们在数值上是不连续的,如果能把他们转化成连续的就好做了。

这里就有一个非常巧妙地转化:

还是以 1?1 为例,我们把它看成一个坐标 ((1,?,1)),那么可能的情况就是 ((1,1,1))((1,0,1))

画出坐标系,不难发现他们在立方体中是连续的一段。

我们可以把它推广到多维,记 (sum_{(x_1,x_2,cdots,x_n)}=displaystyle sum_{i_1=0}^{x_1} sum_{i_2=0}^{x_2} cdots sum_{i_n=0}^{x_n} a_{(i_1,i_2,cdots,i_n)})

这个东西就是一个高维前缀和,可以预处理,我们记录 (S)1? 的位置,枚举 1 构成的集合的子集,然后容斥原理计算答案。

同理,我们也可以记录 0 出现的位置,进行容斥计算答案。

因为 10? 的个数至少有一个小于等于 (lfloor dfrac{n}{3} floor),所以我们统计个数,寻找对应的做法即可。

时间复杂度 (O(2^{lfloor frac{n}{3} floor} imes q))

#include <bits/stdc++.h>
#define reg register
#define fi first
#define se second
#define mp std::make_pair
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
int rd()
{
    reg int x=0,f=0;
    reg char ch=getchar();
    while(!isdigit(ch)) (ch=='-')&&(f=1),ch=getchar();
    while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return f?-x:x;  
}
const int MAXN=1100000;
int n,q;
int pre[MAXN],suf[MAXN],bit[MAXN],val[MAXN];
char s[MAXN];
void work()
{
    n=rd(),q=rd();
    scanf("%s",s);
    for(reg int i=1;i<=(1<<n)-1;++i) bit[i]=bit[i>>1]+(i&1);
    for(reg int i=0;i<=(1<<n)-1;++i) val[i]=pre[i]=suf[i]=s[i]-'0';
    for(reg int i=1;i<=(1<<n)-1;i<<=1) for(reg int j=i;j<=(1<<n)-1;j=(j+1)|i)
        pre[j]+=pre[j^i],suf[j^i]+=suf[j];
    while(q--)
    {
        scanf("%s",s);
        int ans=0,cnt1=0,cnt2=0,cnt3=0,bit1=0,bit2=0,bit3=0;
        for(reg int i=0;i<n;++i)
        {
            if(s[i]=='0') ++cnt1,bit1|=1<<(n-i-1);
            if(s[i]=='1') ++cnt2,bit2|=1<<(n-i-1);
            if(s[i]=='?') ++cnt3,bit3|=1<<(n-i-1);
        }
        if(cnt1<=cnt2&&cnt1<=cnt3)
        {
            int tmp=bit1;
            do
            {
                if(bit[tmp]&1) ans-=suf[tmp|bit2];
                else ans+=suf[tmp|bit2];
                tmp=(tmp-1)&bit1;
            } while(tmp!=bit1);
        }
        else if(cnt2<=cnt1&&cnt2<=cnt3)
        {
            int tmp=bit2;
            do
            {
                if(bit[tmp^bit2]&1) ans-=pre[tmp|bit3];
                else ans+=pre[tmp|bit3];
                tmp=(tmp-1)&bit2;
            } while(tmp!=bit2);
        }
        else
        {
            int tmp=bit3;
            do
            {
                ans+=val[tmp|bit2];
                tmp=(tmp-1)&bit3;
            } while(tmp!=bit3);
        }
        printf("%d
",ans);
    }
}
int main()
{
    int _=1;
    // _=rd();
    while(_--) work();
    return 0;   
}
原文地址:https://www.cnblogs.com/Lonely-233/p/15039779.html