Wannafly Camp 2020 Day 2B 萨博的方程式

给定 (n) 个数 (m_i),求 ((x_1,x_2,...,x_n)) 的个数,使得 (x_1 xor x_2 xor ... xor x_n = k),且 (0 leq x_i leq m_i)

Solution

从最高位开始看起,毫无疑问,如果 (m_i) 的某一位是 (0),那么 (x_i) 的这一位只能填 (0),所以只有那些 (m_i) 最高位是 (1) 的才具有选择权。

考虑从最高位数起,哪一位 (pos) 开始,存在一个 (i) 使得 (x_i eq m_i),很显然这个 (pos) 是有范围的,它一定是从最高位开始往下的一段区间,因为如果某一位上,(m_i) 这一位的异或和 ( eq k) 的这一位,更低的位就可以扔掉了。

假设从第 (pos) 位开始,存在一个 (i) 使得 (x_i eq m_i),我们需要统计所有满足这个条件的答案,不妨把这个部分的贡献称作第 (pos) 位的贡献。

(f[i][j]) 表示前 (i)(m) 的当前位是 (1) 的数中,选择了 (j)(x) 的当前位是 (1)(i-j) 个是 (0) 的方案数,那么

[f[i][j]=f[i-1][j-1]cdot (x_i mod 2^{pos} + 1) + f[i-1][j]cdot 2^{pos} ]

考虑如何统计第 (pos) 位的贡献,假设这位 (1) 的个数为 (cnt),那么 (f[cnt][j]) 答案的贡献是 (f[cnt][j]/2^{pos}),当 (j)(k) 该位的奇偶性相同时产生。

#include <bits/stdc++.h>
using namespace std;

#define int long long
const int N = 105;
const int mod = 1e9+7;
inline void exgcd(int a,int b,int &x,int &y) {
    if(!b) {
        x=1,y=0;
        return;
    }
    exgcd(b,a%b,x,y);
    int t=x;
    x=y,y=t-(a/b)*y;
}

inline int inv(int a,int b) {
    int x,y;
    return exgcd(a,b,x,y),(x%b+b)%b;
}

int n,k,m[N],f[N][N];

int solve(int pos) {
    if(pos<0) return 1; //!
    int ret=0,cnt=0;
    memset(f,0,sizeof f);
    f[0][0]=1;
    for(int i=1;i<=n;i++) {
        if((m[i]>>pos)&1) {
            ++cnt;
            f[cnt][0]=f[cnt-1][0]*(1<<pos)%mod; //!
            for(int j=1;j<=cnt;j++) {
                f[cnt][j]=f[cnt-1][j-1]*(m[i]%(1<<pos)+1)
                        +f[cnt-1][j]*(1<<pos);
                f[cnt][j]%=mod;
            }
        }
        else {
            for(int j=0;j<=cnt;j++) {
                f[cnt][j]=f[cnt][j]*(m[i]+1); //!
                f[cnt][j]%=mod;
            }
        }
    }
    int r=inv(1<<pos,mod);
    for(int j=(k>>pos&1);j<cnt;j+=2) {
        ret+=f[cnt][j]*r;
        ret%=mod;
    }
    if((cnt&1) == ((k>>pos)&1)) {
        for(int i=1;i<=n;i++) {
            if(m[i]>>pos&1) m[i]^=(1<<pos);
        }
        return (solve(pos-1) + ret)%mod;
    }
    else return ret;
}

signed main() {
    ios::sync_with_stdio(false);
    while(cin>>n>>k) {
        for(int i=1;i<=n;i++) cin>>m[i];
        cout<<solve(31)<<endl;
    }
}

原文地址:https://www.cnblogs.com/mollnn/p/12336422.html