【10.3校内测试【国庆七天乐!】】【DP+组合数学/容斥】【spfa多起点多终点+二进制分类】

最开始想的暴力DP是把天数作为一个维度所以怎么都没有办法优化,矩阵快速幂也是$O(n^3)$会爆炸。

但是没有想到另一个转移方程:定义$f[i][j]$表示每天都有值的$i$天,共消费出总值$j$的方案数。然后答案就是

所以每次维护前缀和就可以$O(1)$转移了。

注意前缀和的初值。

#include<bits/stdc++.h>
#define LL long long
#define mod 998244353
using namespace std;

int n, m;
LL d;
LL dp[2005][2005], sum[2005][2005];

LL mpow(LL a, LL b) {
    LL ans = 1;
    for(; b; b >>= 1, a = a * a % mod)
        if(b & 1)    ans = ans * a % mod;
    return ans;
}

LL rev(LL a) {
    return mpow(a, mod - 2);
}

LL comb(LL p, int q) {
    LL a = 1, b = 1;
    for(LL i = p - q + 1; i <= p; i ++)
        a = i % mod * a % mod;
    for(int i = 1; i <= q; i ++)
        b = b * i % mod;
    LL ans = a * rev(b) % mod;
    return ans;
}

int main() {
    freopen("contract.in", "r", stdin);
    freopen("contract.out", "w", stdout);
    while(cin >> n >> d >> m) {
        if(n == 0 && d == 0 && m == 0)    break;
        d %= mod;
        int now = 0;
        memset(sum, 0, sizeof(sum));
        memset(dp, 0, sizeof(dp));
        for(int i = 1; i < m && i <= n; i ++)
            dp[1][i] = 1;
        for(int i = 1; i <= n; i ++)
            sum[1][i] = sum[1][i-1] + dp[1][i];
        for(int i = 2; i <= n && i <= d; i ++) {
            for(int j = 1; j <= n; j ++) {
                if(j - m > 0)    dp[i][j] = (sum[i-1][j-1] - sum[i-1][j-m] + mod) % mod;
                else dp[i][j] = sum[i-1][j-1];
                sum[i][j] = (sum[i][j-1] + dp[i][j]) % mod;
            }
        }
        LL ans = 0;    
        for(int i = 1; i <= n && i <= d; i ++) {
            LL tmp = comb(d, i);
            ans = (ans + tmp * dp[i][n] % mod) % mod;
        }
        printf("%lld
", ans);
    }
    return 0;
} 

起点确定的最小环。

我们可以发现,因为环的起点和终点都是1,所以题目实际是找与1相连的一个起点和一个终点(因为要保证没有走重边,所以起点和终点一定不同),而对于两个不同的数,二进制位上一定有至少一位不相同,所以可以按每一位,将二进制中当前位不同的点分成两组,代表当前起点和终点,每次跑一遍多起点多终点的$Spfa$,统计最小答案即可。

【注意】不能把每次跑完得到的起点终点直接两两配对,因为两点不一定能相互到达,还是应该在$Spfa$中赋初值跑完。

#include<bits/stdc++.h>
#define oo 0x3f3f3f3f
using namespace std;

int n, m, tot;

struct Node {
    int u, v, nex, w;
    Node(int u = 0, int v = 0, int nex = 0, int w = 0) :
        u(u), v(v), nex(nex), w(w) { }
} Edge[800005];

int stot, h[100005];
void add(int u, int v, int s) {
    Edge[++stot] = Node(u, v, h[u], s);
    h[u] = stot;
}

int vis[100005], dis[100005], S[100005], T[100005], nums, numt, W[800005], rt[100005];
queue < int > q;
void Spfa() {
    memset(vis, 0, sizeof(vis));
    memset(dis, 0x3f3f3f3f, sizeof(dis));
    for(int i = 1; i <= nums; i ++)    q.push(S[i]), vis[S[i]] = 1, dis[S[i]] = W[S[i]];
    while(!q.empty()) {
        int x = q.front(); q.pop(); vis[x] = 0;
        for(int i = h[x]; i; i = Edge[i].nex) {
            int v = Edge[i].v;
            if(dis[v] > dis[x] + Edge[i].w && v != 1) {
                dis[v] = dis[x] + Edge[i].w;
                if(!vis[v]) {
                    vis[v] = 1;    q.push(v);
                }
            }
        }
    }
}

int main() {
    freopen("leave.in", "r", stdin);
    freopen("leave.out", "w", stdout);
    int t;
    scanf("%d", &t);
    while(t --) {
        scanf("%d%d", &n, &m);
        stot = 0, tot = 0;
        memset(h, 0, sizeof(h));
        memset(W, 0, sizeof(W));
        memset(rt, 0, sizeof(rt));
        int ans = 0x3f3f3f3f;
        for(int i = 1; i <= m; i ++) {
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            add(a, b, c);    add(b, a, c);
            if(b < a)    swap(a, b);
            if(a == 1) rt[++tot] = b, W[b] = c;
        }
        if(tot <= 1) {
            printf("-1
"); continue;
        }
        sort(rt + 1, rt + 1 + tot);
        int M = rt[tot];
        int tmp = 0;
        while(M) {
            memset(S, 0, sizeof(S));
            memset(T, 0, sizeof(T));
            nums = 0; numt = 0;
            int t = M & 1;
            for(int i = 1; i <= tot; i ++)
                if(((rt[i] >> tmp) & 1) == t)    S[++nums] = rt[i];
                else    T[++numt] = rt[i];
            Spfa();
            for(int i = 1; i <= numt; i ++)
                ans = min(ans, W[T[i]] + dis[T[i]]);
            M >>= 1; tmp ++;
        }
        if(ans < oo)    printf("%d
", ans);
        else printf("-1
");
    }
    return 0;
}
原文地址:https://www.cnblogs.com/wans-caesar-02111007/p/9740618.html