BZOJ 4557 侦查守卫

好迷的树形dp。。。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxv 500500
#define maxe 1000500
using namespace std;
int n,m,d,x,y,nume=0,g[maxv],dp1[maxv][23],dp2[maxv][23],w[maxv],fath[maxv];
bool mark[maxv][23];
struct edge
{
    int v,nxt;
}e[maxe];
void addedge(int u,int v)
{
    e[++nume].v=v;
    e[nume].nxt=g[u];
    g[u]=nume;
}
void dp(int x)
{
    int flag=0;
    for (int i=g[x];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v!=fath[x])
        {
            fath[v]=x;
            dp(v);flag=1;
            for (int j=1;j<=d;j++)
                mark[x][j]|=mark[v][j-1];
        }
    }
    if (!flag)
    {
        if (mark[x][0]) dp1[x][0]=dp2[x][0]=w[x];
        for (int i=1;i<=d;i++)
        {
            dp2[x][i]=w[x];
            dp1[x][i]=0;
        }
        return;
    }
    for (int i=1;i<=d;i++)
        for (int j=g[x];j;j=e[j].nxt)
        {
             int v=e[j].v;
             if (v!=fath[x])
                dp1[x][i]+=dp1[v][i-1];
        }
    dp2[x][d]+=w[x];
    for (int i=g[x];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v!=fath[x])
            dp2[x][d]+=dp1[v][d];
    }
    for (int i=d-1;i>=0;i--)
    {
        int ret=0;
        for (int j=g[x];j;j=e[j].nxt)
        {
             int v=e[j].v;
             if (v!=fath[x])
                ret+=dp1[v][i];
        }
        dp2[x][i]=dp2[x][i+1];
        for (int j=g[x];j;j=e[j].nxt)
        {
             int v=e[j].v;
             if (v!=fath[x])
                dp2[x][i]=min(dp2[x][i],dp2[v][i+1]+ret-dp1[v][i]);
        }
    }
    dp1[x][0]=dp2[x][0];
    for (int i=1;i<=d;i++) 
        dp1[x][i]=min(dp1[x][i],dp1[x][i-1]);
    for (int i=d-1;i>=0;i--) 
        if (!mark[x][i]) dp1[x][i]=min(dp1[x][i],dp1[x][i+1]);
    dp2[x][0]=dp1[x][0];
}
int main()
{
    scanf("%d%d",&n,&d);
    for (int i=1;i<=n;i++)   
        scanf("%d",&w[i]);
    scanf("%d",&m);
    for (int i=1;i<=m;i++)
    {
        scanf("%d",&x);
        mark[x][0]=true;
    }
    for (int i=1;i<=n-1;i++)
    {
        scanf("%d%d",&x,&y);
        addedge(x,y);addedge(y,x);
    }
    dp(1);
    printf("%d
",dp1[1][0]);
    return 0;
}
原文地址:https://www.cnblogs.com/ziliuziliu/p/5776709.html