CSP-S 2019 Emiya 家今天的饭

64 pts

类似 乌龟棋 的思想,由于 (64pts)(m <= 3)非常小

我们可以设一个 (dp),建立 (m) 个维度存下每种物品选了几次:

  • (f[i][A][B][C]) 表示前 (i) 种烹饪方法,第 (1 / 2/ 3) 种主要食材各自选了 (A, B, C) 道菜的方案数。

状态转移:根据题意,每种烹饪方法最多选一道菜。

  • 不做菜 (f[i][A][B][C] += f[i - 1][A][B][C])
  • (1) 道第一种主要食材的菜 : (f[i][A][B][C] += f[i - 1][A - 1][B][C] * a_{i, 1})
  • (1) 道第二种主要食材的菜 : (f[i][A][B][C] += f[i - 1][A][B - 1][C] * a_{i, 2})
  • (1) 道第三种主要食材的菜 : (f[i][A][B][C] += f[i - 1][A][B][C - 1] * a_{i, 3})

答案:(sum_{A = 0}^{n}sum_{B = 0}^{n}sum_{C = 0}^{n}f[n][A][B][C] (max(A, B, C) <= lfloor(A + B + C) / 2 floor 且 A + B + C > 0))


小优化:发现所有状态只会从$A, B, C <= $ 自己的转移,所以可以用类似背包优化空间的思想,从大到小枚举状态,第一维可以滚动掉。

时间复杂度

最多选 (n) 道菜故时间复杂度 (O(n^{m + 1}))

#include <cstdio>
#include <iostream>
using namespace std;
const int N = 45, M = 6, P = 998244353;
int n, m, a[N][M];
typedef long long LL;
int f[N][N][N];
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++) scanf("%d", &a[i][j]);
    
    f[0][0][0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int A = i; ~A; A--) {
            for (int B = i; ~B; B--) {
                for (int C = i; ~C; C--) {
                    if(A && a[i][1]) f[A][B][C] = (f[A][B][C] + (LL)f[A - 1][B][C] * a[i][1]) % P; 
                    if(B && a[i][2]) f[A][B][C] = (f[A][B][C] + (LL)f[A][B - 1][C] * a[i][2]) % P; 
                    if(m == 3 && C && a[i][3]) f[A][B][C] = (f[A][B][C] + (LL)f[A][B][C - 1] * a[i][3]) % P; 
                }
            }
        }
    }
    int ans = 0;
    for (int A = n; ~A; A--) {
        for (int B = n; ~B; B--) {
            for (int C = n; ~C; C--) {
                int s = A + B + C;
                if(s > 0 && max(A, max(B, C)) <= s / 2) (ans += f[A][B][C]) %= P;
            }
        }
    }
    printf("%d
", ans);
    return 0;
}

84 pts

发现 (64pts) 后的 (m) 猛增,所以我们的算法一定不能具体记录每种主要食材选了多少了。

我们发现一个方案不合法,有且只会有一个主要食材 $ > $ 总数的一半,所以我们不妨考虑容斥,用所有方案数量 - 不合法数量。


求解所有方案数量

所有方案数量很好求,做一个分组背包即可:

(f[i][j]) 表示前 (i) 种烹饪方式,做了 (j) 道菜的方案数。

状态转移:

  • (i) 种烹饪方式不做菜:(f[i][j] += f[i - 1][j])
  • (i) 种烹饪方法做 (1) 道主要食材是 (k) 的菜:(f[i][j] += f[i - 1][j - 1] * a_{i, k})

所有方案数量 $ = sum_{j = 1}^{n}f[n][j]$

优化

  1. (i, j) 以来比它小的 (i', j'),第一维滚动掉
  2. 观察第二种放菜的转移:(f[i - 1][j - 1] * a_{i, 1} + f[i - 1][j - 1] * a_{i, 2} +...+ f[i - 1][j - 1] * a_{i, m} = f[i - 1][j - 1] * (a_{i, 1} + a_{i, 2} + ... + a_{i, m}))。我们可以 (O(nm)) 预处理 (s_i = a_{i, 1} + a_{i, 2} + ... + a_{i, m})。每个状态即可 (O(1)) 转移。
这步的时间复杂度

这步有 (n ^ 2) 个状态,(O(1)) 转移。 所以时间复杂度 (O(n ^ 2))

求解不合法数量

由于刚才我们发现的性质:所有不合法方案中有且只会有一个主要食材 $ > $ 总数的一半,我们称那个主要食材为越界食材,我们设越界食材为 (c)

所以我们不妨先用 (O(m)) 枚举 (c)

那么我们可以把其他食材归结为 符合条件的食材,我们便可以用一个维度来记录它选了多少啦~

(dp[i][j][k]) 为前 (i) 种烹饪方式,第 (c) 种(越界食材)选了 (j) 道,其他食材选了 (k) 道的方案数。

状态转移:

  • (i) 种烹饪方法不做菜:(dp[i][j][k] += dp[i - 1][j][k])(O(1)) 转移

  • 选第 (c) 种(越界食材):(dp[i][j][k] += dp[i - 1][j - 1][k] * a_{i, c} (j > 0)) (O(1)) 转移

  • 选其他食材:(dp[i][j][k] += sum_{u = 1, u != c}^{m}dp[i - 1][j][k - 1] * a_{i, u} (k > 0 ))。$O(m) $ 转移

对答案的贡献

(sum f[n][j][k] (j > k))

优化:

  1. 跟之前一样可以滚动掉第一维

  2. 第三种转移最耗时,考虑用求解所有方案数量优化2的思想:(sum_{u = 1, u != c}^{m}dp[i - 1][j][k - 1] * a_{i, u} = dp[i - 1][j][k - 1] * (sum_{u = 1, u != c}^{m}a_{i, u}) = dp[i - 1][j][k - 1] * (s_i - a_{i, c})) 我们就做到了 (O(1)) 转移。

这步的时间复杂度

(O(m)) 枚举越界食材后,做一个 (O(n ^ 3))(dp)。求解不合法数量的总时间复杂度为 (O(n ^ 3m))

总时间复杂度:(O(n ^ 3m))

#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 105, M = 2005, P = 998244353;
int n, m, a[N][M], f[N], s[N];
int dp[N][N];
/*
dp[i][j] 表示不合法的选了 i 个,剩下的总共选了 j 个的方案数

*/
LL ans = 0;
/*
f[i] 表示做了 i 道菜的方案数
*/

void inline add(int &x, LL y) {
    x = (x + y) % P;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++) {
            scanf("%d", &a[i][j]);
            s[i] = ((LL)s[i] + a[i][j]) % P;
        }
    
    
    f[0] = 1;
    for (int i = 1; i <= n; i++)
        for (int j = i; j; j--)
            add(f[j], (LL)f[j - 1] * s[i]);
    
    for (int i = 1; i <= n; i++) ans = (ans + f[i]) % P;
    
    for (int c = 1; c <= m; c++) {
        memset(dp, 0, sizeof dp);
        dp[0][0] = 1;
        for (int i = 1; i <= n; i++) {
            for (int j = i; ~j; j--) {
                for (int k = i - j; ~k; k--) {
                    if(j) add(dp[j][k], (LL)dp[j - 1][k] * a[i][c]);
                    if(k) add(dp[j][k], (LL)dp[j][k - 1] * (s[i] - a[i][c]));
                }
            }
        }
        
        for (int j = 1; j <= n; j++) {
            for (int k = 0; k < j; k++) ans = (ans - dp[j][k] + P) % P;
        }
    }
    printf("%lld
", ans);
    return 0;
}

100pts

延续 (84pts) 的思想,求解不合法数量的 (O(n ^ 3m)) 拖累了我们,我们考虑优化。

我们不关系具体越界食材与其他食材选了多少。只用保证越界食材数 $ > $ 其他食材数数即为不合法状态。


不妨把这两个的差作为一个维度,这样即可让 (dp) 状态降一维:

  • (dp[i][j]) 表示前 (i) 中烹饪方法,越界食材数 $ - $ 其他食材数 为 (j) 的方案数。

状态转移:

  • (i) 种烹饪方法不选:(dp[i][j] += dp[i - 1][j])
  • 选越界食材 (c)(dp[i][j] += dp[i - 1][j - 1] * a_{i, c})
  • 选其他食材:(dp[i][j] += dp[i - 1][j + 1] * (s_i - a_{i, c}))

答案贡献:

(sum dp[n][j] (j > 0))

总时间复杂度 (O(n ^ 2m)) 完美通过本题。

(Tips)

  1. 注意做差有可能为负数,我们可以把所有状态加一个 (+n) 的偏移量就不会数组越界了。

  2. 不要忘记取模!!

#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 105, M = 2005, P = 998244353;
int n, m, a[N][M], f[N], s[N];
int dp[N][N << 1];
/*
dp[i][j] 表示不合法的选了 i 个,剩下的总共选了 j 个的方案数

*/
LL ans = 0;
/*
f[i] 表示做了 i 道菜的方案数
*/

void inline add(int &x, LL y) {
    x = (x + y) % P;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++) {
            scanf("%d", &a[i][j]);
            s[i] = ((LL)s[i] + a[i][j]) % P;
        }
    
    
    f[0] = 1;
    for (int i = 1; i <= n; i++)
        for (int j = i; j; j--)
            add(f[j], (LL)f[j - 1] * s[i]);
    
    for (int i = 1; i <= n; i++) ans = (ans + f[i]) % P;
    
    for (int c = 1; c <= m; c++) {
        memset(dp, 0, sizeof dp);
        dp[0][n] = 1;
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n + i; j++) {
                dp[i][j] = (dp[i - 1][j] + (LL)dp[i - 1][j - 1] * a[i][c] + (LL)dp[i - 1][j + 1] * (s[i] - a[i][c])) % P;
            }
        }
        
        for (int j = n + 1; j <= n * 2; j++) ans = (ans - dp[n][j] + P) % P;
    }
    printf("%lld
", ans);
    return 0;
}
原文地址:https://www.cnblogs.com/dmoransky/p/11916846.html