hdu5977 2016年大连区域赛 树上点分治

算是补了去年磊哥拉的题,早上先学了两节课点分治1,然后发现别人总结出的点分治用于求与树上两点之间的路径或者是点对的问题

分析了一下为什么能够通过算出一个点对于问题的贡献之后再无视这个结点,也就是删除了这个结点解决问题,

设两点之间的路径必定经过x点,x点能在树上形成的贡献是以链的形式,那么x点必定作为链的端点或者是经过的点,如果已经计算了x点的贡献,x点作为根节点的子树之间的点之间不可能再有贡献,这样就把一个规模大的树上路径或者点对问题转化成了多个规模小的问题,找点的方式就是找重心还是挺友好的

然后针对这一题的话,网上的做法很多是树上点分治之后套树形背包加容斥,用将重心的贡献计算上的情况除去在重心为根的某一个子树中形成1<<k-1的情况,而早上闲下来的我转化了思想,转化成了树形dp问题,用一个1<<10的桶来保留之前子树对状态st1的贡献,用另外一个桶st2保留当前子树的贡献,然后再dp,奈何搞半天没弄出来,然后dd做出来了,思路其实差不多,分治之后对子树搜索,每一个结点遍历1<<k次来更新贡献

贴上d巨巨被修改过的代码

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int N = 5e4 + 10;
int a[N];
int n, k, size[N], rt, tot, st[N], mx[N];
vector<int>g[N];
ll ans;
int num[1100];

void getrt(int u, int fa)
{
    mx[u] = 0; size[u] = 1;
    for (auto &j: g[u])
    {
        if(j == fa || st[j]) continue;
        getrt(j, u);
        size[u] += size[j];
        mx[u] = max(mx[u], size[j]);
    }
    mx[u] = max(mx[u], tot - size[u]);
    if(mx[rt] > mx[u])
        rt = u;
}

void dfs(int u, int fa, int val, int t)
{
    t |= (1 << a[u]);
    num[t] += val;
    for (auto &j : g[u])
    {
        if(j == fa || st[j]) continue;
        dfs(j, u, val, t);
    }
}

void calans(int u, int fa, int t)
{
    t |= (1 << a[u]);
    for (int i = 0; i < (1 << k); i ++)
        if((t|i) == (1<<k) - 1)
            ans += num[i];
    for (auto &j : g[u])
    {
        if(j == fa || st[j]) continue;
        calans(j, u, t);
    }
}

void calc(int u, int fa)
{
    dfs(u, 0, 1, 0);
    ans += num[(1 << k) - 1];
    for (auto &j : g[u])
    {
        if(j == fa || st[j]) continue;
        dfs(j, u, -1, (1 << a[u]));
        calans(j, u, 0);
        dfs(j, u, 1, (1 << a[u]));
    }
    dfs(u, 0, -1, 0);
}

void solve(int u)
{
    st[u] = 1;
    calc(u, 0);
    for (auto &j: g[u])
    {
        if(st[j]) continue;
        rt = 0;
        tot = size[j];
        getrt(j, 0);
        solve(rt);
    }
}

int main()
{
    while(scanf("%d%d", &n, &k) != EOF)
    {
        ans = 0;
        for (int i = 1; i <= n; i ++) g[i].clear(), st[i] = 0;
        for (int i = 1; i <= n; i ++)
        {
            scanf("%d", &a[i]);
            a[i] --;
        }
        for (int i = 1; i <= n - 1; i ++)
        {
            int a, b; scanf("%d%d", &a, &b);
            g[a].push_back(b); g[b].push_back(a);
        }
        rt = 0; tot = mx[rt] = n;
        getrt(1, 0);
        solve(rt);
        printf("%lld
", ans);
    }
}
View Code
原文地址:https://www.cnblogs.com/Urchin-C/p/13756067.html