Codeforces 691E题解 DP+矩阵快速幂

题面

传送门:http://codeforces.com/problemset/problem/691/E
E. Xor-sequences
time limit per test: 3 seconds
memory limit per test: 256 megabytes
input: standard input
output: standard output
You are given n integers $a_1, a_2, \ldots, a_n$.

A sequence of integers $x_1, x_2, \ldots, x_k$ is called a “xor-sequence” if for every $1 \le i \le k-1$ the number of ones in the binary representation of the number $x_i \oplus x_{i+1}$ is a multiple of 3, and $x_i \in \{a_1, a_2, \ldots, a_n\}$ for all $1 \le i \le k$. The symbol $\oplus$ is used for the binary exclusive or operation.

How many “xor-sequences” of length k exist? Output the answer modulo $10^9 + 7$.

Note if a = [1, 1] and k = 1 then the answer is 2, because you should consider the ones from a as different.

Input
The first line contains two integers n and k ($1 \le n \le 100$, $1 \le k \le 10^{18}$) — the number of given integers and the length of the “xor-sequences”.

The second line contains n integers $a_i$ ($0 \le a_i \le 10^{18}$).

Output
Print the only integer c — the number of “xor-sequences” of length k modulo $10^9 + 7$.

Examples
input
5 2
15 1 2 4 8
output
13
input
5 1
15 1 2 4 8
output
5
题目大意:给定长度为 n 的序列 a,从中选出 k 个数(可以重复选择)排成序列 $x_1, x_2, \ldots, x_k$,要求对每个 $1 \le i \le k-1$,$x_i \oplus x_{i+1}$ 的二进制表示中 1 的个数是 3 的倍数。问满足条件的长度为 k 的序列有多少种(答案对 $10^9+7$ 取模)?

分析:

1.DP方程的推导

$dp[i][j]$ 表示长度为 $i$、且最后一个数取 $a[j]$ 的合法序列的方案数

$$dp[i][j]=\sum_{x=1}^{n} dp[i-1][x]$$
(只对满足 $\mathrm{bitcount}(a[x] \oplus a[j]) \bmod 3 = 0$ 的 $x$ 求和)
bitcount(x)表示x的二进制中1的个数

2.矩阵快速幂的优化

我们发现,状态转移方程的形式类似矩阵乘法
因为可以这样变形:$dp[i][j]=\sum_{x=1}^{n} dp[i-1][x]\times[\mathrm{bitcount}(a[x] \oplus a[j]) \bmod 3 = 0]$,其中 $[P]$ 在条件 $P$ 成立时取 1,否则取 0

所以状态转移方程可以写成这样
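$$\begin{bmatrix} dp[i][1] \\ dp[i][2] \\ \vdots \\ dp[i][n] \end{bmatrix} = \begin{bmatrix} 1 & 0 & \cdots & 1 \\ 0 & 1 & \cdots & 1 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & \cdots & 1 \end{bmatrix} \times \begin{bmatrix} dp[i-1][1] \\ dp[i-1][2] \\ \vdots \\ dp[i-1][n] \end{bmatrix}$$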

右边那个由很多 1 和 0 组成的矩阵是 $n \times n$ 的状态转移矩阵,第 $i$ 行第 $j$ 列为 1 代表 $\mathrm{bitcount}(a[i] \oplus a[j]) \bmod 3 = 0$,否则为 0
显然对角线全部为1(一个数异或它本身为0)
在DP之前我们可以预处理这个矩阵

由于最终答案为 $\sum_{j=1}^{n} dp[k][j]$,
状态转移方程可以改写为
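$$\begin{bmatrix} dp[k][1] \\ dp[k][2] \\ \vdots \\ dp[k][n] \end{bmatrix} = \begin{bmatrix} 1 & 0 & \cdots & 1 \\ 0 & 1 & \cdots & 1 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & \cdots & 1 \end{bmatrix}^{k-1} \times \begin{bmatrix} dp[1][1] \\ dp[1][2] \\ \vdots \\ dp[1][n] \end{bmatrix}$$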

首先很显然每个数构成一个满足条件的序列,所以dp[1][j]=1
所以没有必要存储dp数组,直接计算出矩阵k-1次方,再将矩阵内所有的值加起来即可

时间复杂度分析:
预处理时间复杂度 $O(n^2)$
矩阵乘法时间复杂度 $O(n^3)$
快速幂时间复杂度 $O(\log_2 k)$
总时间复杂度 $O(n^3 \log_2 k)$

代码:

//CF 691E
#include<iostream>
#include<cstdio>
#include<cstring> 
// Capacity of each matrix dimension. The code indexes from 1, so at most
// SIZE-1 rows/columns are usable. A typed constant replaces the former macro
// and is still a valid array bound.
const int SIZE=105;
// Capacity of the input array `num` (also indexed from 1).
const int maxn=105;
using namespace std;
// Modulus applied to matrix products and to the final answer.
const long long mod=1000000007;
int n;// number of input integers, read in main
long long k;// requested sequence length, read in main
long long num[maxn];// input integers, stored at num[1..n]
struct matrix {//矩阵
    int n;//长
    int m;//宽
    long long a[SIZE][SIZE];
    matrix() {//构造函数
        n=2;
        m=2;
        memset(a,0,sizeof(a));
    }
    matrix(int x,int y) {
        n=x;
        m=y;
        memset(a,0,sizeof(a));
    }
    void print() {
        for(int i=1; i<=n; i++) {
            for(int j=1; j<=m; j++) {
                printf("%d ",a[i][j]);
            }
            printf("
");
        }
    }
    void setv(int x) {//初始化
        if(x==0) {
            memset(a,0,sizeof(a));
        }
        if(x==1) {
            memset(a,0,sizeof(a));
            for(int i=1; i<=n; i++) a[i][i]=1;
        }
    }
    friend matrix operator *(matrix x,matrix y) {//矩阵乘法
        matrix tmp=matrix(x.n,y.m);
        for(int i=1; i<=x.n; i++) {
            for(int j=1; j<=y.m; j++) {
                tmp.a[i][j]=0;
                for(int k=1; k<=y.n; k++) {
                    tmp.a[i][j]+=(x.a[i][k]*y.a[k][j])%mod;
                }
                tmp.a[i][j]%=mod;
            }
        }
        return tmp;
    }
};
// Raises x to the k-th power by binary exponentiation (square-and-multiply).
// x is taken by value because it is squared in place.
// For k <= 0 the loop never runs and the identity matrix is returned.
// The identity is sized from x itself rather than from the global n, so the
// function no longer depends on global state.
// NOTE(review): x is expected to be square; this is not checked.
matrix fast_pow(matrix x,long long k) {
    matrix ans=matrix(x.n,x.m);
    ans.setv(1);// start from the multiplicative identity
    while(k>0) {
        if(k&1) {
            ans=ans*x;
        }
        k>>=1;
        // Square only while bits remain: the squaring after the last bit
        // would cost a full matrix product whose result is never used.
        if(k>0) {
            x=x*x;
        }
    }
    return ans;
}
// Returns how many bits are set in x.
// The scan runs only while the remaining value is positive, so zero and
// negative arguments yield 0.
long long count_1(long long x) {
    long long ones=0;
    for(long long rest=x; rest>0; rest>>=1) {
        ones+=(rest&1);// add the lowest bit, then shift it out
    }
    return ones;
}
int main() {
    scanf("%d %I64d",&n,&k);
    for(int i=1; i<=n; i++) scanf("%I64d",&num[i]);
    matrix xor_mat=matrix(n,n);
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            if(count_1(num[i]^num[j])%3==0 )xor_mat.a[i][j]=1;
            else xor_mat.a[i][j]=0;
        }
    }
    xor_mat=fast_pow(xor_mat,k-1);
    long long ans=0;
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            ans=ans+xor_mat.a[i][j];
        }
        ans%=mod;
    }
    printf("%I64d
",ans);
}
原文地址:https://www.cnblogs.com/birchtree/p/9858046.html