2020.10.05提高组模拟

6813. 【2020.10.05提高组模拟】路哥


Description

给定一棵有 n 个点的树,问有多少种断边方案使得包括 1 的连通块权值和为 (k),求断边方案除以总方案数((2^n - 1)) (mod 998244353)

Data Constraint

(1 leq n, k leq 5000),所有点的权值小于等于5000

Solution

暴力设 (f_{u, i}) 表示以 u 为根的子树中,u 所在的连通块的权值和为 i 的概率和。
转移显然 (f_{u, i} = sum{f_{u, j} * f_{v, i - j} / 2} + f_{u, i} / 2)
时间复杂度 (O(nk^2)) 所以需要考虑如何优化。

Method1

可以发现如果 (v) 不选,那么以它为根子树内的点就都不能选。

优化:每次 dg 时将 (u) 当前的方案下传给儿子 (v),做完 (v) 后,(u) 的概率和就是 (v) 的概率和加上不选 (v) 的概率和。
正确性:每次将 (u) 的概率和下传给 (v),计算出来的其实就是连接 ((u, v)) 这一条边的概率和,所以回溯的时候加上断开 ((u, v)) 的概率和。

Method2

其实只需要在转移的时候掐紧上下界就可以了,树上背包不会了唔。

Code

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

#define N 5000
#define Mod 998244353
#define p 499122177

#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)

void read(int &x) {
    char ch = getchar(); x = 0;
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}

struct EDGE { int next, to; } edge[N << 1];

int f[N + 1][N + 1], g[N + 1], head[N + 1], a[N + 1];

int n, m;

int cnt_edge = 1;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }
void Link(int u, int v) { Add(u, v), Add(v, u); }

void Dfs(int u, int fa) {
    if (fa) {
        fo(i, a[u], m)
            f[u][i] = 1ll * f[fa][i - a[u]] * p % Mod;
    } else
        f[u][a[u]] = 1;
    Fo(i, u) if (edge[i].to != fa) {
        int v = edge[i].to;
        Dfs(v, u);
        fo(i, 0, m)
            f[u][i] = (f[v][i] + 1ll * f[u][i] * p % Mod) % Mod;
    }
}

int main() {
    freopen("luge.in", "r", stdin);
    freopen("luge.out", "w", stdout);

    read(n), read(m);
    fo(i, 1, n) read(a[i]);
    for (int i = 1, x, y; i < n; i ++)
        read(x), read(y), Link(x, y);

    Dfs(1, 0);

    printf("%d
", f[1][m]);

    return 0;
}

6811. 【2020.10.05提高组模拟】密电


Description

有一个长度为 n 的数列 a,现将它元素之间两两相加得到一个长度为 (frac{n(n - 1)}{2}) 的数列 b。
给出 b,求所有可能的 a。

Data Constraint

(3 leq n leq 500),b 中任意元素不大于 (2 * 10^8)

Solution

先将 b 从小到大排序。
可以确定 (a_1 + a_2 = b_1)(a_1 + a_3 = b_2),那么考虑枚举 (a_2 + a_3 = b_x),显然 (x leq n)(讲题时说 (x leq n + 3),但是 n 也过了)。
求出 (a_1, a_2, a_3) 后,删去 (b_1, b_2, b_x),那么最小的 (b = a_1 + a_4),计算出 (a_4) 后,删去 (a_4 + a_i), (1 leq i leq 4)

同理可以计算出 (a_5),删除后可以计算出 (a_6)......

一开始删除的时候选择扫一遍 b 数组,但是好像会跑到 (O(n^4)) 过不了,改成二分还是 TLE 了两个点,最后 bitset 优化过掉了...

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <bitset>

using namespace std;

#define N 500

#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define Fo(i, s) for(int i = s; i <= m; i = a[i].r)

void read(int &x) {
    char ch = getchar(); x = 0;
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}

int a[N * N + 1];

int c[N + 1], q[N * N + 1], d[N << 3][N + 1];

bitset <N * N + 1> used;

int n, m, tot = 0, ans = 0;

bool Calc(int x, int y, int z) {
    int sum = x + y + z;
    if (sum % 2) return 1;
    sum >>= 1;
    c[1] = sum - z, c[2] = sum - y, c[3] = sum - x;
    return 0;
}

int Find(int x) {
    int l = 1, r = m, mid = 0, k = 0;
    while (l <= r) {
        mid = l + r >> 1;
        a[mid] < x ? l = (k = mid) + 1 : r = mid - 1;
    }
    if (a[ ++ k ] != x) return 0;
    if (! used[k]) k = used._Find_next(k);
    if (a[k] != x) return 0;
    return ! used[k] ? 0 : k;
}

bool Del(int x) {
    int k = Find(x);
    if (! k) return 0;
    used[k] = 0;
    q[ ++ tot ] = k;
    return 1;
}

void Back() {
    int k = q[ tot -- ];
    used[k] = 1;
}

void Dfs(int t) {
    if (t > n) {
        ++ ans;
        fo(i, 1, n) d[ans][i] = c[i];
        return;
    }
    c[t] = a[used._Find_first()] - c[1];
    fo(i, 1, t - 1) if (! Del(c[i] + c[t])) {
        fo(j, 2, i) Back();
        return;
    }
    Dfs(t + 1);
    fo(i, 2, t) Back();
}

int main() {
    freopen("telegram.in", "r", stdin);
    freopen("telegram.out", "w", stdout);

    read(n); m = n * (n - 1) / 2;
    fo(i, 1, m) read(a[i]);

    sort(a + 1, a + 1 + m);
    used.set(); used[0] = 0;
    Del(a[1]), Del(a[2]);
    fo(i, 3, n) {
        if (i > 3 && a[i] == a[i - 1]) continue;
        if (Calc(a[1], a[2], a[i])) continue;
        if (c[1] <= 0 || c[2] <= 0 || c[3] <= 0) continue;
        Del(a[i]);
        Dfs(4);
        Back();
    }

    printf("%d
", ans);
    fo(i, 1, ans) {
        fo(j, 1, n) printf("%d ", d[i][j]);
        puts("");
    }

    return 0;
}
原文地址:https://www.cnblogs.com/zhouzj2004/p/13771061.html