Hdu 6268 点分治 树上背包 bitset 优化

给你一棵有 n(n ≤ 3000)个节点的树,树上每个点有点权(点权 ≤ 100000),再给你一个数 m(m ≤ 100000)。

对于 1~m 中的每个 i,问树中是否存在一个连通子图,使得其点权之和恰好为 i。

每次 solve 到一个节点(当前连通块的重心)时,用 bitset 维护所有包含该节点的连通子图能取到的权值和(calc 前要先初始化当前节点的 bitset)。

复杂度为 O(n · log n · m / 64)。

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. Only `ll` is referenced by the visible code (for `mod`);
// the rest appear to be unused template boilerplate.
#define ll long long
#define ull unsigned long long
#define mst(a,b) memset((a),(b),sizeof(a))
#define mp(a,b) make_pair(a,b)
#define pi acos(-1)
#define pii pair<int,int>
#define pb push_back
// INF, eps and mod are not referenced anywhere in the visible code.
const int INF = 0x3f3f3f3f;
const double eps = 1e-6;
// Capacity of every per-vertex array (vec, used, a, sz, son, bit).
const int maxn = 3e3 + 10;
// Width of each bitset, i.e. one more than the largest weight sum that can be tracked.
const int maxm = 1e5 + 10;
const ll mod =  998244353;

// n: number of vertices, m: largest weight sum that is queried (both read in main).
int n,m;
// Adjacency lists of the tree; vertices are numbered from 1.
vector<int>vec[maxn];
// used[v] is set once v has been chosen as a decomposition root by solve();
// traversals skip such vertices.
bool used[maxn];
// a[v]    : weight of vertex v.
// root    : best split vertex found so far by getroot(); index 0 is a sentinel.
// sz[v]   : subtree size of v from the most recent getroot()/calc() traversal.
// son[v]  : size of the largest piece left after removing v (computed by getroot()).
// all     : number of vertices in the component currently handed to getroot().
int a[maxn],root,sz[maxn],son[maxn],all;

void getroot(int u,int fa) {
    sz[u] = 1, son[u] = 0;
    for(int i = 0; i < vec[u].size(); i++) {
        int v = vec[u][i];
        if(used[v] || v == fa) continue;
        getroot(v,u);
        sz[u] += sz[v];
        son[u] = max(son[u],sz[v]);
    }
    son[u] = max(son[u],all - son[u]);
    if(son[u] < son[root]) root = u;
}

// bit[v]: working set for vertex v during calc(); bit s is set when weight sum s
//         is reachable by a vertex set built so far that includes v.
// ans   : union of the results collected by solve() over all decomposition roots.
bitset<maxm>bit[maxn],ans;

void calc(int u,int fa) {
    sz[u] = 1, bit[u] <<= a[u];
    for(int i = 0; i < vec[u].size(); i++) {
        int v = vec[u][i];
        if(used[v] || v == fa) continue;
        bit[v] = bit[u];
        calc(v,u);
        sz[u] += sz[v];
        bit[u] |= bit[v];
    }
}

void solve(int u) {
    used[u] = true;
    bit[u].reset(), bit[u].set(0);
    calc(u,0);
    ans |= bit[u];
    for(int i = 0; i < vec[u].size(); i++) {
        int v = vec[u][i];
        if(used[v]) continue;
        root = 0;
        all = sz[v];
        getroot(v,0);
        solve(root);
    }
}

/**
 * Reads `t` test cases. Each case gives n and m, then n-1 tree edges, then the
 * n vertex weights. For every case it prints one line of m characters, where
 * the i-th character (1-based) is '1' if weight sum i was recorded in `ans`
 * and '0' otherwise.
 */
int main() {
#ifdef local
    freopen("data.txt", "r", stdin);
//    freopen("data.txt", "w", stdout);
#endif
    int t;
    // Stop cleanly on missing input instead of looping on an uninitialized t.
    if (scanf("%d",&t) != 1) return 0;
    while(t--) {
        ans.reset();
        if (scanf("%d%d",&n,&m) != 2) break;
        // Reset per-case state; index 0 is included because it is the sentinel vertex.
        for(int i = 0; i <= n; i++) vec[i].clear(),used[i] = false;
        for(int i = 1; i < n; i++) {
            int u,v;
            if (scanf("%d%d",&u,&v) != 2) return 0;
            vec[u].push_back(v);
            vec[v].push_back(u);
        }
        for(int i = 1; i <= n; i++) {
            if (scanf("%d",&a[i]) != 1) return 0;
        }
        // son[0] is a sentinel larger than any component, so getroot() always
        // replaces root 0 with a real vertex.
        root = 0;
        son[0] = 1e9;
        all = n;
        getroot(1,0);
        solve(root);
        for(int i = 1; i <= m; i++) printf("%d",(int)ans[i]);
        // The newline must be written as an escape sequence; a string literal
        // cannot contain a raw line break.
        printf("\n");
    }
    return 0;
}

 

原文地址:https://www.cnblogs.com/Aragaki/p/10590513.html