Codeforces 908G New Year and Original Order 数位dp

用类似于数位 DP 的方式，去求每个数字的贡献。不过我的写法好像非常麻烦。

其实转化一下之后, 有很好写的方法。

#include<bits/stdc++.h>
// Shorthand macros in the usual competitive-programming style.
// Integer / floating-point type abbreviations.
#define LL long long
#define LD long double
#define ull unsigned long long
// Pair member access and construction.
#define fi first
#define se second
#define mk make_pair
// Common pair types (these rely on `pair` being visible unqualified).
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
// Container size as a signed int. The argument is parenthesized so that an
// expression such as `cond ? a : b` is evaluated as a whole before `.size()`.
#define SZ(x) ((int)(x).size())
// Full iterator range of a container.
#define ALL(x) (x).begin(), (x).end()
// Untie C++ streams from C stdio. NOTE: expands to two statements, so it must
// not be used as the unbraced body of an `if`/`for`.
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

// Array dimension used by every table in this file: 700 plus a little slack.
const int N = 700 + 7;
// "Infinity" sentinels whose bytes are all 0x3f.
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
// Prime modulus under which all counts are kept.
const int mod = 1000000007;
// Floating-point tolerance and pi.
const double eps = 1e-8;
const double PI = acos(-1.0);

/// In-place modular addition: a += b, followed by one conditional subtraction
/// of `mod`. The result is only fully reduced when a + b is below 2 * mod.
template<class T, class S>
inline void add(T& a, S b) {
    a += b;
    if (mod <= a) {
        a -= mod;
    }
}
/// In-place modular subtraction: a -= b, followed by one conditional addition
/// of `mod` when the difference went negative.
template<class T, class S>
inline void sub(T& a, S b) {
    a -= b;
    if (0 > a) {
        a += mod;
    }
}
/// Raises `a` to `b` when `a < b`.
/// @return true if `a` was changed, false otherwise.
template<class T, class S>
inline bool chkmax(T& a, S b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}
/// Lowers `a` to `b` when `a > b`.
/// @return true if `a` was changed, false otherwise.
template<class T, class S>
inline bool chkmin(T& a, S b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}

// ---- Input-derived state (filled in by main) ----
int n;          // number of digits of the input bound X
int cnt;        // (X + 1) mod `mod`: how many integers lie in [0, X]
int Pow[N];     // Pow[i] = 10^i mod `mod`
int v[N];       // digits of X, v[0] being the least significant one
char s[N];      // raw input text (reversed in place by main)

// ---- Tables rebuilt by every call to solve(x) ----
// Memo for getRet1 in the unconstrained case; valid for one banned digit only.
int dp[N];
// f[i][j][h]: digit strings of length i + 1 with exactly j digits > x.
// h == 1 when the digit x occurs in the string, h == 0 when it does not.
int f[N][N][2];
// g[i][j][h]: same as f, but j counts the digits >= x.
int g[N][N][2];
// sum[i][j][h]: f[i][k][h] summed over all k >= j.
int sum[N][N][2];
// pre[i][j][h]: g[i][k][h] summed over all k <= j.
int pre[N][N][2];

// ---- Query parameters read by getRet2 (set by solve before each call) ----
int d;          // digit under consideration
int w;          // position under consideration, 0 = weight 10^0

/**
 * Counts the ways to choose the digits at positions p..0 so that none of them
 * equals `ban`.
 *
 * @param p     Highest position still to be chosen; -1 means nothing is left.
 * @param ban   Digit that must not appear.
 * @param limit True while every higher digit equals the bound's digit, so
 *              position p may not exceed v[p].
 * @return The count, reduced through add().
 *
 * Results for limit == false are cached in dp[p]. The cache is not keyed by
 * `ban`, so dp must be reset to -1 whenever `ban` changes.
 */
int getRet1(int p, int ban, bool limit) {
    if (p == -1) {
        return 1;
    }

    const bool cacheable = !limit;
    if (cacheable && dp[p] != -1) {
        return dp[p];
    }

    const int top = limit ? v[p] : 9;
    int total = 0;
    for (int digit = 0; digit <= top; ++digit) {
        if (digit != ban) {
            // Stay on the tight path only when the bound's own digit is taken.
            const bool stillTight = limit && digit == top;
            add(total, getRet1(p - 1, ban, stillTight));
        }
    }

    if (cacheable) {
        dp[p] = total;
    }
    return total;
}

/**
 * Counts the integers in [0, X] (X given by v[], padded with leading zeros to
 * n digits) that contain digit d but fail the test `big <= w < big2`, where
 *   big  = number of digits strictly greater than d,
 *   big2 = number of digits greater than or equal to d.
 * A number fails when big > w or big2 <= w. The two cases cannot overlap
 * because big <= big2. Presumably `big <= w < big2` is exactly "digit d sits
 * at weight 10^w once the digits are sorted in non-decreasing order".
 *
 * Reads the globals d, w, v and the tables sum / pre, which must already have
 * been built for digit d.
 *
 * @param p     Highest position still to be chosen; -1 means the number is complete.
 * @param big   Digits > d chosen so far.
 * @param big2  Digits >= d chosen so far.
 * @param have  Whether digit d has been chosen so far.
 * @param limit True while every higher digit equals the bound's digit.
 * @return The count modulo `mod`.
 */
int getRet2(int p, int big, int big2, bool have, bool limit) {
    if (p == -1) {
        return (have && big > w) || (have && big2 <= w);
    }
    if (!limit) {
        // Positions p..0 are free, so answer from the precomputed tables of
        // (p + 1)-digit strings instead of recursing.
        // Case big > w: at least `need` further digits must be > d.
        int need = std::max(0, w - big + 1);
        // Case big2 <= w: at most `need2` further digits may be >= d.
        int need2 = std::min(p + 1, w - big2);
        int ret = 0;
        // If d has not appeared yet, it must appear in the free part (index 1).
        if (have) ret = (sum[p][need][0] + sum[p][need][1]) % mod;
        else ret = sum[p][need][1];
        if (need2 >= 0) {
            if (have) add(ret, (pre[p][need2][0] + pre[p][need2][1]) % mod);
            else add(ret, pre[p][need2][1]);
        }
        return ret;
    }
    // From here on limit is true: position p is capped by the bound's digit,
    // and only choosing that digit keeps the next position capped.
    int up = v[p];
    int ret = 0;
    for (int i = 0; i <= up; i++) {
        add(ret, getRet2(p - 1, big + (i > d), big2 + (i >= d), have || (i == d), i == up));
    }
    return ret;
}

int solve(int x) {
    int ret = 0;
    memset(f, 0, sizeof(f));
    memset(g, 0, sizeof(g));
    memset(dp, -1, sizeof(dp));
    memset(sum, 0, sizeof(sum));
    f[0][1][0] = 9 - x;
    f[0][0][1] = 1;
    f[0][0][0] = x;
    for(int i = 1; i < n; i++) {
        for(int j = i + 1; j >= 0; j--) {
            add(f[i][j][0], 1LL * f[i - 1][j][0] * x % mod);
            if(j) add(f[i][j][0], 1LL * f[i - 1][j - 1][0] * (9 - x) % mod);
            add(f[i][j][1], 1LL * f[i - 1][j][1] * (x + 1) % mod);
            if(j) add(f[i][j][1], 1LL * f[i - 1][j - 1][1] * (9 - x) % mod);
            add(f[i][j][1], f[i - 1][j][0]);
        }
    }
    g[0][1][0] = 9 - x;
    g[0][1][1] = 1;
    g[0][0][0] = x;
    for(int i = 1; i < n; i++) {
        for(int j = i + 1; j >= 0; j--) {
            add(g[i][j][0], 1LL * g[i - 1][j][0] * x % mod);
            if(j) add(g[i][j][0], 1LL * g[i - 1][j - 1][0] * (9 - x) % mod);
            add(g[i][j][1], 1LL * g[i - 1][j][1] * x % mod);
            if(j) add(g[i][j][1], 1LL * g[i - 1][j - 1][1] * (10 - x) % mod);
            if(j) add(g[i][j][1], 1LL * g[i - 1][j - 1][0]);
        }
    }
    for(int i = 0; i < n; i++) {
        for(int j = i + 1; j >= 0; j--) {
            sum[i][j][0] = (f[i][j][0] + sum[i][j + 1][0]) % mod;
            sum[i][j][1] = (f[i][j][1] + sum[i][j + 1][1]) % mod;
        }
    }
    for(int i = 0; i < n; i++) {
        for(int j = 0; j <= i + 1; j++) {
            pre[i][j][0] = g[i][j][0];
            if(j) add(pre[i][j][0], pre[i][j - 1][0]);
            pre[i][j][1] = g[i][j][1];
            if(j) add(pre[i][j][1], pre[i][j - 1][1]);
        }
    }
    int ncnt = (cnt - getRet1(n - 1, x, 1) + mod) % mod;
    for(int i = 0; i < n; i++) {
        int tmp = ncnt;
        d = x, w = i;
        sub(tmp, getRet2(n - 1, 0, 0, 0, 1));
        add(ret, 1LL * Pow[i] * tmp % mod * x % mod);
    }
    return ret;
}

/**
 * Reads the decimal bound X from stdin, sums solve(i) for the digits 1..9
 * modulo `mod`, and prints the result followed by a newline.
 */
int main() {
    // Powers of ten modulo `mod`.
    Pow[0] = 1;
    for(int i = 1; i < N; i++)
        Pow[i] = 1LL * Pow[i - 1] * 10 % mod;
    // The field width keeps the read inside s[N]. 701 digits is enough for
    // X <= 10^700 (assumed problem limit; confirm against the statement).
    scanf("%701s", s);
    n = strlen(s);
    // Store digits least-significant first.
    reverse(s, s + n);
    // Walk from the most significant digit down, building X mod `mod` in cnt.
    for(int i = n - 1; i >= 0; i--) {
        v[i] = s[i] - '0';
        cnt = 1LL * cnt * 10 % mod;
        add(cnt, v[i]);
    }
    // cnt becomes X + 1, the number of integers in [0, X].
    add(cnt, 1);
    int ans = 0;
    for(int i = 1; i <= 9; i++) add(ans, solve(i));
    printf("%d\n", ans);
    return 0;
}

/*
*/
原文地址:https://www.cnblogs.com/CJLHY/p/11074319.html