HDU-4529 郑厂长系列故事——N骑士问题 状态压缩DP

题意:给定一个合法的八皇后棋盘,现在给定1-10个骑士,问这些骑士不能够相互攻击的拜访方式有多少种。

分析:一开始想着搜索写,发现该题和八皇后不同,八皇后每一行只能够摆放一个棋子,因此搜索收敛的很快,但是骑士的话就需要对每一个格子分两种情况进行,情况非常的多,搜索肯定是会超时的。状态压缩DP就是另外一个思路的,理论上时间复杂度是8*n*2^24,但是由于限制比较多,也就能够解决了。设dp[i][j][p][q]表示第i-1行压缩后的状态是p,第i行压缩后的状态为q,且之前一共使用了j个骑士的方案数。按照题意递推即可。

#include <cstdio>
#include <cstring>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const int LIM = 1<<8;
const int M = 10;
int n;
int dp[2][M+1][LIM][LIM];
// dp[i][j][p][q]表示第i-1行状态为p,第i行状态为q,并且一共使用j个骑士的状态数
int G[M];
int tot[LIM];
char f1[LIM][LIM]; // 相邻两层两个状态之间是否冲突 
char f2[LIM][LIM]; // 与上上行两个状态之间是否冲突 

void pre() {
    for (int i = 0; i < LIM; ++i) {
        for (int j = 0; j < 8; ++j) {
            if (i & (1 << j)) ++tot[i];
        }
        for (int j = 0; j < LIM; ++j) {
            if ((i>>2)&j || (j>>2)&i) f1[i][j] = 1;
            if ((i>>1)&j || (j>>1)&i) f2[i][j] = 1;
        }
    }
}

void solve() {
    int cur = 0, nxt = 1;
    memset(dp, 0, sizeof (dp));
    dp[cur][0][0][0] = 1;
    for (int i = 0; i < 8; ++i) { // 由dp[i]来推导dp[i+1]
        for (int j = 0; j <= n; ++j) {
            for (int p = 0; p < LIM; ++p) {
                for (int q = 0; q < LIM; ++q) {
                    if (!dp[cur][j][p][q]) continue;
                    for (int z = 0; z < LIM; ++z) {
                        if ((z & G[i+1]) != z) continue;
                        if (tot[z] + j > n) continue;
                        if (i >= 1 && f1[q][z]) continue;
                        if (i >= 2 && f2[p][z]) continue;
                        dp[nxt][tot[z]+j][q][z] += dp[cur][j][p][q];
                    }
                }
            }
        }
        memset(dp[cur], 0, sizeof (dp[cur]));
        swap(cur, nxt);
    }
    int ret = 0;
    for (int i = 0; i < LIM; ++i) {
        for (int j = 0; j < LIM; ++j) {
            ret += dp[cur][n][i][j];
        }
    }
    printf("%d
", ret);
}

int main() {
    int T;
    char str[10];
    pre();
    scanf("%d", &T);
    while (T--) {
        memset(G, 0, sizeof (G));
        scanf("%d", &n);
        for (int i = 1; i <= 8; ++i) {
            scanf("%s", str);
            for (int j = 0; j < 8; ++j) {
                G[i] <<= 1;
                if (str[j] == '.') G[i] |= 1;
            }
        }
        solve();
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Lyush/p/3416032.html