[Luogu]P2182 翻硬币

Before the Beginning

转载请将本段放在文章开头显眼处,如有二次创作请标明。

原文链接:https://www.codein.icu/lp2182/

题意

(N) 个硬币,一次翻转恰好 (M) 个,恰好 (K) 次后到达目标状态的方案数。

朴素解法

先将正反面转化为是否与目标状态相同方便处理。
定义 (f(i,k)) 为翻转了 (i) 次,与目标状态有 (k) 个硬币相同的方案数。
考虑转移,枚举翻转 (m) 个硬币中,(j) 个硬币从相同翻为不同, (m - j) 个硬币从不同翻为相同。
那么计算得出新的相同数量即为 (k + m - 2 imes j),判断是否在合法范围内。
从状态 (f(i-1,k)) 转移到 (f(i,k + m - 2 imes j)) 时,要在 (k) 选择 (j) 个相同的硬币翻转为不同,(n - k) 中选择 (m - j) 个硬币翻转为相同,乘上相应的组合数即可。
可以得到朴素的 DP 解法,复杂度约为 (O(kn^2))
代码实现中使用了滚动数组优化空间。

#include <stdio.h>
#include <memory.h>
const int maxn = 110;
const int mod = 1000000007;
int n,k,m;
int dif;
char s1[maxn],s2[maxn];
long long f[2][maxn];//f[i] i sames.
long long c[maxn][maxn];
int main()
{   
    scanf("%d %d %d",&n,&k,&m);
    c[1][1] = 1;
    for(int i = 2;i<=n+1;++i)
        for(int j = 1;j<=i;++j)
            c[i][j] = (c[i-1][j] + c[i-1][j-1]) % mod;
    scanf("%s",s1+1);
    scanf("%s",s2+1);
    for(int i = 1;i<=n;++i) dif += (s1[i] != s2[i]);
    int now = 0,last = 1;
    f[last][n - dif] = 1;
    while(k--)
    {
        memset(f[now],0,sizeof(f[now]));
        for(int i = 0;i<=n;++i)// origin same
            for(int j = 0;j<=m;++j)//choose j same->diff, n - j diff->same
            {
                int num = i - j + m - j;
                if(num > n || num < 0) continue;
                f[now][num] += ((f[last][i] * c[i+1][j+1]) % mod * c[n-i+1][m-j+1]) % mod;
                if(f[now][num] > mod) f[now][num] %= mod;
            }
        now ^= 1,last ^= 1;
    }
    printf("%lld",f[last][n] % mod);
    return 0;
}

矩乘优化

不难发现,每次的转移都是在累加上次的某个状态的结果乘上某个系数,并且系数是不变的。
(f(i,k + m - 2 imes j) += f(i - 1,k) imes ldots)
那么整理一下,一定可以写成:
(f(i,k) = sum f(i-1,j) imes A(j,k))
可以使用矩阵乘法来加速DP,复杂度减少为 (O(n^3 log k))
具体地,用原先的DP过程来构造 (A) 矩阵,用矩阵快速幂获得 (A^k),再乘上初始状态矩阵即可获得最终状态矩阵。
该方法在 (n) 较小, (k) 较大时表现出色,但在本题中并无优势。

代码实现中,将矩阵整体下标加一避免零的出现。

#include <cstdio>
const int maxn = 110;
const long long mod = 1000000007;
int n, m, k, cnt;
char s1[maxn], s2[maxn];
long long c[maxn][maxn];
inline void init()
{
    c[1][1] = 1ll;
    for (int i = 2; i <= n + 5; ++i)
        for (int j = 1; j <= i; ++j)
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
struct matrix
{
    int n, m;
    long long a[maxn][maxn];
    matrix(int n, int m)
    {
        this->n = n, this->m = m;
        for (int i = 1; i <= n; ++i) 
            for (int j = 1; j <= m; ++j)
                a[i][j] = 0;
    }
    matrix operator*(const matrix b)
    {
        matrix res(n, b.m);
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= m; ++j)
                for (int k = 1; k <= b.m; ++k)
                    res.a[i][k] = (res.a[i][k] + a[i][j] * b.a[j][k]) % mod;
        return res;
    }
};
signed main()
{
    scanf("%d %d %d", &n, &k, &m), init();
    scanf("%s", s1 + 1), scanf("%s", s2 + 1);
    for (int i = 1; i <= n; ++i) cnt += (s1[i] == s2[i]);
    matrix A(n + 1, n + 1), S(1, n + 1), T(n + 1, n + 1);
    for (int i = 0; i <= n; ++i) for (int j = 0; j <= m; ++j)
    {
        int num = i + m - 2 * j;
        if (num > n || num < 0) continue;
        A.a[i + 1][num + 1] = (A.a[i + 1][num + 1] + c[i + 1][j + 1] * c[n - i + 1][m - j + 1]) % mod;
    }
    S.a[1][cnt + 1] = 1;
    for (int i = 1; i <= T.n; ++i) T.a[i][i] = 1;
    for (; k; k >>= 1)
    {
        if (k & 1) T = T * A;
        A = A * A;
    }
    S = S * T;
    printf("%lld
", S.a[1][n + 1]);
    return 0;
}
原文地址:https://www.cnblogs.com/Clouder-Blog/p/lp2182.html