[ZJOI 2019] 麻将

[题目链接]

        https://loj.ac/problem/3042

[题解]

       首先考虑将期望拆开 , 有 E(x) = sigma { P (x > i) }

       我们需要求出i张牌仍不能胡牌的概率 , 显然可以转化为求方案数。

       直接动态规划是不好做的 , 但如果我们能将当前手上的麻将状态压成一个数 , 那么就可以设 dp(i , j , k)表示前i种麻将 , 状态为j , 一共选了k张牌的方案数。

       如何将状态压缩呢?

       考虑给定一手牌 , 如何判定它是否能胡牌?

       我们发现对于每个i , 形如"i , i + 1 , i + 2"这样的"顺子"是不超过3个的 , 因为如果超过可以用形如"i , i , i"这样类型的“面子”来替换。 那么可以设f(i , j , k , l)表示前i种牌 , (i - 1)开头的“顺子”有j个 , i开头的“顺子”有k个 , 另有l ( l <= 1) 个对子。 这样一共有O(18N)种状态 , 直接转移即可 , 具体细节不再赘述。

       进一步观察这个动态规划 , 我们发现i其实是无关紧要的 , 这个过程的本质就是每次新加一种类型的麻将 ,更新一个3 * 3的矩阵。 不妨考虑建立有限状态自动机(DFA) , 直接将这个3 * 3的矩阵做为状态 , 将"胡"的节点做为终止节点。 暴力构建这个自动机 , 发现其状态数很小 ,为3956

       那么我们就解决了状态压缩的问题。

       回到刚才的思路 , 不妨设dp(i , j , k)表示加入了i种类型的麻将 , 现在在自动机上j号节点 , 一共选了k张牌的方案数。 用一些组合数学的技巧就可以实现转移 , 具体细节不再赘述。

       那么这道题就做完了 , 时间复杂度 : O(N ^ 2M) (M为自动机的状态数)

[代码]

       

/*
      Author : @evenbao
      Created : 2020 / 07 / 29 
*/

#ifdef _MSC_VER
#define _CRT_SECURE_NO_WARNINGS
#endif

#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

#define pii pair<int , int>
#define mp make_pair
#define fi first
#define se second

const int N = 1e2 + 5;
const int M = 4e3 + 5;
const int mod = 998244353;

template <typename T> inline void chkmax(T &x , T y) { x = max(x , y); }
template <typename T> inline void chkmin(T &x , T y) { x = min(x , y); }
template <typename T> inline void read(T &x) {
    T f = 1; x = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}

inline void inc(int &x , int y) { 
        x = x + y < mod ? x + y : x + y - mod; 
}
inline void dec(int &x , int y) {
        x = x - y >= 0 ? x - y : x - y + mod;
}
inline int quickpow(int a , int n) {
        int b = a , res = 1;
        for (; n; n >>= 1 , a = (LL) a * a % mod)
                if (n & 1) res = (LL) res * a % mod;
        return res;
}
struct State {
        int dp[3][3];
        State() {
                memset(dp , 255 , sizeof(dp));
        }
        friend bool operator < (State a , State b) {
                for (int i = 0; i < 3; ++i)
                        for (int j = 0; j < 3; ++j)
                                if (a.dp[i][j] != b.dp[i][j])
                                        return a.dp[i][j] < b.dp[i][j];
                return false;
        }
        friend State Max(State a , State b) {
                State c;
                for (int i = 0; i < 3; ++i)
                        for (int j = 0; j < 3; ++j)
                                c.dp[i][j] = max(a.dp[i][j] , b.dp[i][j]);
                return c;
        }
        friend State Trans(State a , int b) {
                State c;
                for (int i = 0; i < 3; ++i)
                        for (int j = 0; j < 3; ++j)
                                if (~a.dp[i][j])
                                        for (int k = 0; k < 3 && i + j + k <= b; ++k)
                                                chkmax(c.dp[j][k] , min(i + a.dp[i][j] + (b - i - j - k) / 3 , 4));
                return c;
        }
} ;

struct Mahjong {
        pair < State , State > god;
        int cnt;
        Mahjong() {
                memset(god.first.dp , 255 , sizeof(god.first.dp));
                memset(god.second.dp , 255 , sizeof(god.second.dp));
                god.first.dp[0][0] = cnt = 0;
        }
        friend bool operator < (Mahjong a , Mahjong b) {
                return a.cnt != b.cnt ? a.cnt < b.cnt : a.god < b.god;
        }
        friend Mahjong Trans(Mahjong a , int b) {
                a.cnt = min(a.cnt + (b >= 2) , 7);
                a.god.second = Trans(a.god.second , b);
                if (b >= 2)
                        a.god.second = Max(a.god.second , Trans(a.god.first , b - 2));
                a.god.first = Trans(a.god.first , b);
                return a;
        }
        inline bool right() {
                if (cnt >= 7) return 1;
                for (int i = 0; i < 3; ++i)
                        for (int j = 0; j < 3; ++j)
                                if (god.second.dp[i][j] == 4) return 1;
                return 0;
        }
} mahjong[M];

int n , tot;
map < Mahjong , int > idx;
bool win[M];
int org[N] , dp[N][M][4 * N] , trans[M][5] , fac[M] , c[M][5];

inline void Dfs_Mahjong(Mahjong now) {
        if (idx.find(now) != idx.end()) return;
        mahjong[++tot] = now;
        win[tot] = now.right();
        idx[now] = tot;
        for (int i = 0; i <= 4; ++i)
                Dfs_Mahjong(Trans(now , i));
}

int main() {
        
        fac[0] = 1;
        for (int i = 1; i < M; ++i) 
                fac[i] = (LL) fac[i - 1] * i % mod;
        for (int i = 0; i < M; ++i) {
                c[i][0] = 1;
                for (int j = 1; j <= min(i , 4); ++j)
                        c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
                continue;
        }
        Dfs_Mahjong(Mahjong());
        for (int i = 1; i <= tot; ++i)
        for (int j = 0; j <= 4; ++j)
                trans[i][j] = idx[Trans(mahjong[i] , j)];
        scanf("%d" , &n);
        for (int i = 0; i < 13; ++i) {
                int x; scanf("%d%*d" , &x);
                ++org[x];
        }
        dp[0][1][0] = 1;
        for (int i = 0 , cp = 0; i < n; ++i) {
                cp += org[i + 1];
                for (int j = 1; j <= tot; ++j) {
                        for (int l = org[i + 1]; l <= 4; ++l) {
                                int *nf = dp[i + 1][trans[j][l]] , *ff = dp[i][j];
                                int tmp = (LL) c[4 - org[i + 1]][l - org[i + 1]] * fac[l - org[i + 1]] % mod;
                                for (int k = 0; k + l <= 4 * n; ++k) {
                                        if (!ff[k]) continue;
                                        inc(nf[k + l] , (LL) ff[k] * tmp % mod * c[k + l - cp][l - org[i + 1]] % mod);
                                }
                        }
                }
        }
        int ans = 0 , dw = 1;
        for (int i = 13; i <= 4 * n; ++i) {
                int up = 0;
                for (int j = 1; j <= tot; ++j)
                        if (!win[j]) inc(up , dp[n][j][i]);
                inc(ans , (LL) up * quickpow(dw , mod - 2) % mod);
                dw = (LL) dw * (4 * n - i) % mod;
        }
        printf("%d
" , ans);
      return 0;
}
原文地址:https://www.cnblogs.com/evenbao/p/13398233.html