树形DP 学习笔记

树形DP学习笔记

ps: 本文内容与蓝书一致

树的重心

  • 概念: 一颗树中的一个节点其最大子树的节点树最小
  • 解法:对与每个节点求他儿子的(size) ,上方子树的节点个数为(n-size_u) ,求对于每个节点子树的最大值,找出最小的那个就好了;

(我觉得就不需要code了)


树的直径

  • 概念:一颗带权树的最长路径
  • 解法:维护一个节点到叶子节点的最大距离(d1[i])和次大距离(d2[i]) ,最大距离就是$max {d1[i]+d2[i] } $

code

#include<iostream>
#include<cstdio>
using namespace std;
const int N=1e4+5;
int n;
struct pp
{
    int to,next;
}w[2*N];
int head[N],cnt;
int d1[N],d2[N];
int ans;
void add(int x,int y)
{
    cnt++;
    w[cnt].next=head[x];
    w[cnt].to=y;
    head[x]=cnt;
}
void dfs(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs(t,x);
            if(d1[t]+1>d1[x])
            {
                d2[x]=d1[x];
                d1[x]=d1[t]+1;
            }
            else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
        }
    }
    return ;
}
void find_ans(int x,int fa)
{
    ans=max(ans,d1[x]+d2[x]);
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa) find_ans(t,x);
    }
    return;
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("diam.in","r",stdin);
    freopen("diam.out","w",stdout);
#endif
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(1,0);
    find_ans(1,0);
    printf("%d",ans);
    return 0;
}

例题

P4480 逃学的小孩

  • 大概思路:求出树的直径以及其左右端点,再设(d[i])为树上节点(i)到左右端点距离更小的那个,然后求出(max {d[i]}),然后以这个值加上直径就是(ans)

code

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int N=2e5+5;
struct pp
{
    int next,to;
    ll qu;
}w[N*2];
int head[N],cnt;
int n,m;
bool v[N];
ll d1[N],d2[N],dl[N],dr[N];
int f1[N],f2[N];
int r,l;
ll ans,mans;
void add(int x,int y,int z)
{
    w[++cnt].next=head[x];
    w[cnt].qu=z;
    w[cnt].to=y;
    head[x]=cnt;
}
int read()
{
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}
void dfs1(int x)
{
    if(v[x]) return ;
    v[x]=1;
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(!v[t])
        {
            dfs1(t);
            if(d1[t]+w[i].qu>d1[x])
            {
                f2[x]=f1[x];
                f1[x]=f1[t];
                d2[x]=d1[x];
                d1[x]=d1[t]+w[i].qu;
            }
            else if(d1[t]+w[i].qu>d2[x]) d2[x]=d1[t]+w[i].qu,f2[x]=f1[t];
        }
        
    }
    return;
}
void find_ans(int x)
{
    if(v[x]) return;
    v[x]=1;
    if(ans<d1[x]+d2[x])
    {
        ans=d1[x]+d2[x];
        l=f1[x];
        r=f2[x];
    }
    for(int i=head[x];i;i=w[i].next) find_ans(w[i].to);
}
void dfs2(int x)
{
    if(v[x]) return;
    v[x]=1;
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(!v[t])
        {
            dl[t]=dl[x]+w[i].qu;
            dfs2(t);
        }
    }
    return;
}
void dfs3(int x)
{
    if(v[x])return;
    v[x]=1;
    
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(!v[t])
        {
            dr[t]=dr[x]+w[i].qu;
            dfs3(t);
        }
    }
    return;
}
void dfs_ans(int x)
{
    if(v[x]) return;
    v[x]=1;
    mans=max(mans,min(dl[x],dr[x]));
    for(int i=head[x];i;i=w[i].next) dfs_ans(w[i].to);
    return;
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("Chris.in","r",stdin);
    freopen("Chris.out","w",stdout);
#endif
    n=read();
    m=read();
    for(int i=1;i<=m;i++)
    {
        int x,y,z;
        x=read(),y=read(),z=read();
        add(x,y,z);
        add(y,x,z);
    }
    for(int i=1;i<=n;i++) f1[i]=i;
    dfs1(1);
    memset(v,0,sizeof(v));
    find_ans(1);
    memset(v,0,sizeof(v));
    dfs2(l);
    memset(v,0,sizeof(v));
    dfs3(r);
    memset(v,0,sizeof(v));
    dfs_ans(1);
    printf("%lld",ans+mans);
    return 0;
}

树的中心

  • 概念:给出一颗带权树,求一个节点,使得此节点到树中其他节点的最远距离最小;

  • 解法:如果是一颗没有负边权的树,那直接找到直径的中点就好;

    但是这里我们考虑有负边权的情况:

    有两种情况:

    1. (u)点向上的最长路径,设为(up[u]);
    2. (u)点向下,即(u)到叶节点的最远距离,设为(d1[u])(次远设为(d2[u]));

    (d1[u])(d2[u])都会求,问题是(up[u])该怎么求?

    还是分类讨论,设(u)的父亲为(x),(d1[x])来自于子节点(v);那对于(u):

    1. 如果(u!=v),那么(up[u]=max{d1[x],up[x]}+dis[x][t]);
    2. 如果(u==v),那么(up[u]=max{d2[x],up[x]}+dis[x][t]),这也是为什么要维护(d2[x])的原因;

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
    int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int root,far;
int read()
{
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}
void add(int x,int y)
{
    cnt++;
    w[cnt].next=head[x];
    w[cnt].to=y;
    head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs1(t,x);
            if(d1[t]+1>d1[x])
            {
                pre[x]=t;
                d2[x]=d1[x];
                d1[x]=d1[t]+1;
            }
            else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
        }
    }
    return;
}
void dfs2(int x,int fa)
{
    int minx=min(u[x],d1[x]);
    if(far<minx)
    {
        root=x;
        far=minx;
    }
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if (t!=fa)
        {
            if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
            else u[t]=max(d2[x],u[x])+1;
            dfs2(t,x);
        }
    }
    return ;
}
int main()
{
    n=read(),k=read();
    for(int i=1;i<n;i++)
    {
        int x,y;
        x=read(),y=read();
        add(x,y);
        add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    printf("%d",root);
    return 0;
}

例题

P5536核心城市

  • 思路:显然其中一定会有一个城市为这颗树的中心;那找出这个中心,把这颗无根树变为以它为根的有根树;再求出除根节点以外的每个节点所能到达的最大深度(deepfar[i]),这就是这个节点最远所能到达的距离;然后(sort)一下(deepfar[]),答案就是(deepfar[k+1]+1);

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
    int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int fardeep[N];
int root,far;
int read()
{
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}
void add(int x,int y)
{
    cnt++;
    w[cnt].next=head[x];
    w[cnt].to=y;
    head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs1(t,x);
            if(d1[t]+1>d1[x])
            {
                pre[x]=t;
                d2[x]=d1[x];
                d1[x]=d1[t]+1;
            }
            else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
        }
    }
    return;
}
void dfs2(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if (t!=fa)
        {
            if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
            else u[t]=max(d2[x],u[x])+1;
            dfs2(t,x);
        }
    }
    return ;
}
void dfs3(int x,int fa)
{
    int minx=min(u[x],d1[x]);
    if(far<minx)
    {
        root=x;
        far=minx;
    }
    for(int i=head[x];i;i=w[i].next) if(w[i].to!=fa) dfs3(w[i].to,x);
    return;
}
void dfs4(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs4(w[i].to,x);
            fardeep[x]=max(fardeep[x],fardeep[t]+1);
        }
    }
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("XR-3.in","r",stdin);
    freopen("XR-3.out","w",stdout);
#endif
    n=read(),k=read();
    for(int i=1;i<n;i++)
    {
        int x,y;
        x=read(),y=read();
        add(x,y);
        add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    dfs3(1,0);
    dfs4(root,0);
    sort(fardeep+1,fardeep+1+n,cmp);
    printf("%d",fardeep[k+1]+1);
    return 0;
}

上面都是有关树的一些经典题型,下面才是今天的主角——树型DP


背包类树型DP

(我觉得把,其实左右子树类树型DP可以归为这一类)

例题

选课

书上的是时间复杂度为(n^3)的算法,这里介绍一个优化,可以讲其降为(n^2);

  • 泛化物品优化:具体是什么,请参考2009年国家集训队论文——徐持衡《浅谈几类背包问题》,其中有详细解释;

  • 而我对泛化物品优化的感性理解就是:"预留空间"——为在 (u) 到到根节点的路径上(包括u)的点预留空间。

    这样就可以在对 (u)DP的时候保证他所依赖的物品预先算进去了

    (dp[u][j])的意思就是在预留(u)及其到根节点的路径上的点的空间后,还剩下(j)的空间的最大价值;

  • 没有优化前,DP方程为:

  • 没有优化前,DP方程为:

[dp[u][j]=max{dp[u][j],dp[u][j-k]+dp[v][k]} ]

这样对于每个节点都要(n^2)暴力枚举(j)(k);

经过优化,我们的DP方程就变为了:

[egin{cases} dp[v][j]=dp[u][j](dfs前)\ dp[u][j]=max{dp[u][j],dp[v][j-w[v]]+val[v]}(回溯时) end{cases} ]

这也是再泛化物品优化下,树型背包的基本DP方程;这样我们只需要(O(n))枚举(j)就好了;


ps: 以下代码参考价值不大,建议参考[HAOI2010]软件安装

code

#include<iostream>
#include<algorithm>
#include<queue>
#include<cstdio>
#include<cstring>
using namespace std;

int n,m;
struct edge
{
    int next,to;
}e[1000];
int rt,head[1000],tot,val[1000],dp[1000][1000];
void add(int x,int y)
{
    e[++tot].next=head[x];
    head[x]=tot;
    e[tot].to=y;
}
void dfs(int u,int t)
{
    if (t<=0) return ;
    for (int i=head[u]; i; i=e[i].next)
    {
        int v = e[i].to;
        for (int j=0; j<=t-1; ++j) //为v预留空间
            dp[v][j] = dp[u][j];
        dfs(v,t-1);//对于v的现有空间
        for (int j=1; j<=t; ++j) 
            dp[u][j] = max(dp[u][j],dp[v][j-1]+val[v]);//背包
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        int a;
        scanf("%d%d",&a,&val[i]);
        if(a)
          add(a,i);
        if(!a)add(0,i);
    }
    dfs(0,m);
    printf("%d",dp[0][m]);
}

选择类树型DP

基本DP方程:

[vin{son(u)} egin{cases} f[u][0]=sum f[v][1] \ f[u][1]=min{f[v][1],f[v][0]}+1 end{cases} ]

例题

P2016战略游戏

直接套DP方程就好了;

code

#include<iostream>
#include<cstdio>
using namespace std;
int n;
int dp[1605][2];
struct pp
{
	int next,to;
}w[1600<<1];
int head[1600],cnt;
void add(int x,int y)
{
	cnt++;
	w[cnt].to=y;
	w[cnt].next=head[x];
	head[x]=cnt;
}
void dfs(int x,int fa)
{
	dp[x][1]=1;
	for(int i=head[x];i;i=w[i].next)
	{
		int t=w[i].to;
		if(t==fa) continue;
		dfs(t,x);
		dp[x][0]+=dp[t][1];
		dp[x][1]+=min(dp[t][0],dp[t][1]);
	}
	return;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
	{
		int a,k;
		scanf("%d%d",&a,&k);
		for(int i=1;i<=k;i++)
		{
			int b;
			scanf("%d",&b);
			add(a,b);
			add(b,a);
		}
	}
	dfs(0,0);
	printf("%d",min(dp[0][1],dp[0][0]));
	return 0;
}

普通树型DP

这种树型DP更加灵活,就不像前两种有基本固定的DP方程,所以还是直接来几道例题;(滑稽

例题

LOJ #10157. 皇宫看守

乍一看题,啊哈,模板选择树型DP,开开心心打个代码,恭喜你0分;

仔细一看这道题其实不是什么没有上司的舞会,而是一道覆盖DP题,区别在哪呢?

这道题一条边两端至少要有一个点,可以有两个,而没有上司我舞会是一条边两端至多有一个点,可以没有;

那好,这样的话一个节点u的最少经费就不能像选择DP一样单纯的由儿子选不选的而转移过来,因为他们本来互不冲突,而是必须被覆盖到(这里每个节点的覆盖半径为1),这样对于一个节点u的最少经费就可以由覆盖它的节点转移过来,这样的话就需要考虑三种情况:

首先设(dp[u][0])表示被节点(u)被父亲覆盖且(u)不选,(dp[u][1])表示被自己的子节点覆盖且(u)不选,(dp[u][2])表示被自己覆盖;

所以有状态转移方程:

  • 对于(dp[u][0]),因为(u)不选,所以对于(u)的子节点(v),要么被(son(v))所覆盖,要么被(v)自己覆盖:

[dp[u][0]=sum min{dp[v][1],dp[v][2]} +a[f[u]]; ]

  • 对于(dp[u][1]),要保证(u)必须被一个子节点所覆盖到,还要保证(u)的子节点(v)在不被父亲覆盖的前提下被覆盖到,那显然(dp[u][1]),是由(dp[v][1])(dp[v][2])转移过来的,但是如何保证(dp[u][1])的转移中一定包含(dp[v][2])呢?

    这时候有个巧妙的办法,设个参数:

    [d=min{d,dp[v][2]-min{dp[v][1],dp[v][2]}} ]

    (d)的初始值为(0x7fffffff);

    这样对于(dp[u][1])就有状态转移方程:

    [dp[u][1]=sum min{dp[v][1],dp[v][2]}+d ]

  • 对于(dp[u][2]),那很显然它可以由子节点任意三种状态转移过来,但是对于(dp[v][0]),它已经加过一遍(a[u]),而对于(dp[u][2]),只能且必须加一遍(a[u]),那怎么办呢?单独特判由(dp[v][0])转移过来的情况,控制(a[u])只加一遍?显然是可以的,但是太麻烦了,那么另外考虑,这里可以看到(dp[v][0])只会往(dp[u][2])上转移,那么可以根据(dp[u][2])需求对(dp[v][0])状态转移方程改一改:

    [dp[u][0]=sum min{dp[v][1],dp[v][2]} ]

    (这里的(u)是对于(v)来说的)

    感性理解一下就是如果(dp[u][2])不由(dp[v][0])转移过来那要(dp[v][0])也没有什么用,那由(dp[v][0])转移过来,那在(dp[u][2])这加一遍(a[u])就够了,因为(dp[u][2])已经保证了(u)被选,所以不需要(dp[v][0])再保证一遍;

    这样对于(dp[u][2]),就有状态转移方程:

    [dp[u][2]=sum min{dp[v][1],dp[v][2],dp[v][0]} +a[u] ]

总结下来就有三个状态转移方程:

[egin{cases} dp[u][0]=sum min{dp[v][1],dp[v][2]};\ dp[u][1]=sum min{dp[v][1],dp[v][2]}+d ;(d=min{d,dp[v][2]-min{dp[v][1],dp[v][2]}})\ dp[u][2]=sum min{dp[v][1],dp[v][2],dp[v][0]} +a[u] end{cases} ]

(所以,显然书上的状态转移方程是错的)

不难发现,修改后的(dp[v][0])一定小于等于(dp[v][1]);所以写代码的时候我顺手把(dp[u][2])的转移方程改成了:

[dp[u][2]=sum min{dp[v][2],dp[v][0]} +a[u] ]

虽然题目早已经解决了,但我还是想再深究一下;这个方程啥意思?

以我的感性理解就是(v)既然已经一定会被它爹(u)覆盖到,那就可以不需要保证(v)一定被它的儿子所覆盖,修改后的(dp[v][0])刚好就是这种情况;

(好了,bb了这么多废话,就一点有用的东西,直接上代码吧)

code

#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1500 + 5;
int dp[N][3];
int v[N], n, root;
struct pp {
    int next, to;
} w[N];
int head[N], cnt, du[N];
void add(int x, int y) {
    cnt++;
    w[cnt].next = head[x];
    w[cnt].to = y;
    head[x] = cnt;
}
void dfs(int x) {
    int d = 0x7fffffff;
    for (int i = head[x]; i; i = w[i].next) {
        int t = w[i].to;
        dfs(t);
        dp[x][0] += min(dp[t][1], dp[t][2]);
        dp[x][1] += min(dp[t][1], dp[t][2]);
        d = min(d, dp[t][2] - min(dp[t][1], dp[t][2]));
        dp[x][2] += min(dp[t][2], dp[t][0]);
    }
    dp[x][1] += d;
    dp[x][2] += v[x];
}
int main() {
#ifndef ONLINE_JUDGE
    freopen("guard.in", "r", stdin);
    freopen("guard.out", "w", stdout);
#endif
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        int x, m;
        scanf("%d", &x);
        scanf("%d", &v[x]);
        scanf("%d", &m);
        for (int j = 1; j <= m; j++) {
            int y;
            scanf("%d", &y);
            add(x, y);
            du[y]++;
        }
    }
    for (int i = 1; i <= n; i++)
        if (!du[i])
            root = i;
    dfs(root);
    printf("%d", min(dp[root][1], dp[root][2]));
    return 0;
}

好了,差不多就结束了,虽然写这个一点耗时,但对于我这个蒟蒻来说加深了对于DP的理解,收获也不小,也不算浪费时间了吧(逃);


PS: 2020.10.9 添加了我对泛化物品优化的理解

原文地址:https://www.cnblogs.com/Wednesday-zfz/p/12209729.html