hdu5803

hdu5803

题意

给出四个整数 A B C D,问有多少个四元组 (a, b, c, d) 使 a + c > b + d 且 a + d >= b + c ,0 <= a <= A ,0 <= b <= B,0 <= c <= C,0 <= d <= D .

分析

可以用数位dp解决这个问题。

同时对四个数进行数位dp,采用记忆化搜索的形式,搜索过程中考虑剪枝,考虑到百位数时,如果 a + c - b - d >= 2,那么到十位数时(a + c - b - d 最少也才 -18,而前面到十位数会乘 10 即得到 20)一定满足条件了(加上 a + d - b - c 同理),而 a + c - b - d <= -2 那么到十位数时无论如何都不会满足条件了,可以直接剪枝。
但是每次要有 10 * 10 * 10 * 10 的状态转移,还要考虑优化,将四个数转化成二进制数,再进行 dp,前面的剪枝仍然可用。

code

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll MOD = 1e9 + 7;
int bit[5][65];
ll dp[62][5][5][16];
ll dfs(int l, int acbd, int adbc, int limit)
{
    if(l < 0) return acbd > 0 && adbc >= 0;
    ll& res = dp[l][acbd + 2][adbc + 2][limit];
    if(res != -1) return res;
    res = 0;
    int up[4];
    for(int i = 0; i < 4; i++)
    {
        up[i] = (limit >> i & 1) ? bit[i][l] : 1;
    }
    for(int a = 0; a <= up[0]; a++)
    {
        for(int b = 0; b <= up[1]; b++)
        {
            for(int c = 0; c <= up[2]; c++)
            {
                for(int d = 0; d <= up[3]; d++)
                {
                    int acbd_ = acbd, adbc_ = adbc, limit_ = 0;
                    acbd_ = min(acbd_ * 2 + a + c - b - d, 2);
                    adbc_ = min(adbc_ * 2 + a + d - b - c, 2);
                    if(acbd_ <= -2 || adbc_ <= -2) continue;
                    if(a == up[0] && (limit & 1)) limit_ |= 1;
                    if(b == up[1] && (limit >> 1 & 1)) limit_ |= 2;
                    if(c == up[2] && (limit >> 2 & 1)) limit_ |= 4;
                    if(d == up[3] && (limit >> 3 & 1)) limit_ |= 8;
                    (res += dfs(l - 1, acbd_, adbc_, limit_)) %= MOD;
                }
            }
        }
    }
    return res;
}
int main()
{
    int T;
    for(scanf("%d", &T); T--;)
    {
        memset(bit, 0, sizeof bit);
        memset(dp, -1, sizeof dp);
        ll A, B, C, D;
        scanf("%lld%lld%lld%lld", &A, &B, &C, &D);
        for(int i = 0; i < 62; i++)
        {
            bit[0][i] = A >> i & 1;
            bit[1][i] = B >> i & 1;
            bit[2][i] = C >> i & 1;
            bit[3][i] = D >> i & 1;
        }
        printf("%lld
", dfs(60, 0, 0, 15));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/ftae/p/6896736.html