codeforces1029 E.Tree with Small Distances

题目链接

题意:给出一个 (n) 个结点的树,问如何选择结点进行连线使得结点(1)到其他所有结点的最短距离都小于等于 (2)

题解:这道题倒是自己想出来了。

首先有个结论:连线都是从结点1向其他结点连线,因为这样总是最优的。

由题意可知,当结点(1)向某个节点 (u) 连线后,与结点 (u) 直接相连的所有结点都能满足条件。

考虑树形dp。

(d[u][0]):结点u的子结点都满足条件,但是结点u不满足

(d[u][1]):结点u的子树(包括 (u) 自己)都满足条件,但结点u没有被连线

(d[u][2]):结点u的子树(包括 (u) 自己)都满足条件,且结点u与节点1连线

那么,对于叶子结点:

(d[u][0]=0,d[u][1]=inf,d[u][2]=1)

对于非叶子结点 (u) 和它的子结点(v):

(d[u][0]=sum d[v][1];)

(d[u][2]=sum min(d[v][0],min(d[v][1],d[v][2]))+1)

(d[u][1]=sum min(d[v][1],d[v][2]);) //在这个式子中必须保证有至少一个子结点(v) 取的是 (d[v][2])

代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
#define dbg(...) cerr<<"["<<#__VA_ARGS__":"<<(__VA_ARGS__)<<"]"<<endl;
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
const int inf=0x3fffffff;
const ll mod=1000000007;
const int maxn=2e5+10;
int head[maxn],dep[maxn];
int tol;
int d[maxn][3];
struct edge
{
    int to,next;
}e[maxn*2];

void add(int u,int v)
{
    e[++tol].to=v,e[tol].next=head[u],head[u]=tol;
    e[++tol].to=u,e[tol].next=head[v],head[v]=tol;
}
int cnt[maxn]; //节点度数
//d[u][0]-结点u的子结点都满足条件,但是结点u不满足
//d[u][1]-结点u的子树(包括u自己)都满足条件,但结点u没有被连线
//d[u][2]-结点u的子树(包括u自己)都满足条件,且结点u与节点1连线
void dfs(int u,int f)
{
    dep[u]=dep[f]+1;
    d[u][0]=0;
    d[u][1]=cnt[u]==1? 1e6:0;
    d[u][2]=1;
    int mi=1e9;
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f) continue;
        dfs(v,u);
        d[u][0]+=d[v][1];
        d[u][1]+=min(d[v][1],d[v][2]);
        mi=min(mi,d[v][2]-d[v][1]);
        d[u][2]+=min(d[v][0],min(d[v][1],d[v][2]));
        rep(i,0,3) if(d[u][i]>1e6) d[u][i]=1e6;
    }
    if(mi>0) d[u][1]+=mi;
    rep(i,0,3) if(d[u][i]>1e6) d[u][i]=1e6;
}

int main()
{
    int n;
    scanf("%d",&n);
    rep(i,1,n)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        cnt[u]++,cnt[v]++;
        add(u,v);
    }
    dfs(1,0);
    int ans=0;
    rep(i,1,n+1) if(dep[i]==3) ans+=min(d[i][0],min(d[i][1],d[i][2]));
    printf("%d
",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/tarjan/p/9562592.html