HDU-4507 吉哥系列故事——恨7不成妻 数位DP

题意:给定区间[L, R]求区间内与7无关数的平方和。一个数当满足三个规则之一则认为与7有关:
1、整数中某一位是7;
2、整数的每一位加起来的和是7的整数倍;
3、这个整数是7的整数倍;

分析:初看起来确实有点麻烦,数位DP还是很容易看出来的,需要维护好三个值dp[ i ][ j ][ k ].num表示数位和为对7的余数为 j ,前面确定的数对7的余数为 k 的情况下, i 位任意与7无关的数一共有多少个;同理 dp[ i ][ j ][ k ].sum 表示这些数的和为多少;dp[ i ][ j ][ k ].sqr 表示这些数的平方和为多少,这三者之间是可以递推的,详见代码。个人觉得将区间左右边界同时代入求解更加优美。

#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef long long LL;
const int mod = int(1e9)+7;
int hbit[20], lbit[20];
int _POW[20];
struct STATUS {
    int num, sum, sqr;
    bool flag;
    STATUS(int _num, int _sum, int _sqr) : num(_num), sum(_sum), sqr(_sqr) {}
    STATUS() : flag(false) {}
    // sum = sum {cur*10^p*num' + sum'}
    // sqr = sum {num'*(cur*10^p)^2 + sqr' + 2*cur*10^p*sum'}
}dp[20][7][7];

STATUS cal(int p, int srem, int mrem, bool lb, bool hb) {
    if (p == 0) {
        if (srem != 0 && mrem != 0) return STATUS(1, 0, 0);
        else return STATUS(0, 0, 0);
    }
    if (!lb && !hb && dp[p][srem][mrem].flag) {
        return dp[p][srem][mrem];
    }
    STATUS ret(0, 0, 0), tmp;
    int sta = lb ? lbit[p] : 0;
    int end = hb ? hbit[p] : 9;
    for (int i = sta; i <= end; ++i) {
        if (i == 7) continue;
        tmp = cal(p-1, (srem+i)%7, (mrem*10+i)%7, lb&&i==sta, hb&&i==end);
        ret.num = (1LL*ret.num + 1LL*tmp.num) % mod;
        ret.sum = (1LL*ret.sum + 1LL*i*_POW[p-1]%mod*tmp.num%mod+tmp.sum) % mod;
        ret.sqr = (1LL*ret.sqr + 1LL*i*_POW[p-1]%mod*i%mod*_POW[p-1]%mod*tmp.num%mod + 1LL*tmp.sqr + 2LL*i*_POW[p-1]%mod*tmp.sum%mod)%mod;
    }
    if (!lb && !hb) {
        dp[p][srem][mrem] = ret;
        dp[p][srem][mrem].flag = true;
    }
    return ret;
}

int count(LL l, LL r) {
    memset(lbit, 0, sizeof (lbit));
    memset(hbit, 0, sizeof (hbit));
    int lidx = 1, hidx = 1;
    while (l) {
        lbit[lidx++] = l % 10;
        l /= 10;
    }
    while (r) {
        hbit[hidx++] = r % 10;
        r /= 10;
    }
    return cal(max(lidx-1, hidx-1), 0, 0, true, true).sqr;
}

int main() {
    _POW[0] = 1;
    for (int i = 1; i < 20; ++i) {
        _POW[i] = (1LL*_POW[i-1]*10) % mod;
    }
    int T;
    LL l, r;
    scanf("%d", &T);
    while (T--) {
        scanf("%I64d %I64d", &l, &r);
        printf("%d
", count(l, r));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Lyush/p/3307889.html