SGU 134 Centroid

简单树型DP

题意:一个图,N个点,N-1条边,并且说明是树(一棵树,连森林都排除掉了)。在这颗树中删除一个点rt以及和他关联的边,那么剩下的部分将会是森林,统计森林中每棵树的节点数,最大值记录下来为dp[rt],你的任务是找出最小的dp[rt],如果有多个rt相等,那么按升序输出所有rt的编号

一个最大值最小的问题,解法是求出所有的dp[rt]然后记录最小值

因为本来是一棵树,用哪个做顶点都可以,默认为1,为整个树的祖先。

定义sum[rt] = 以rt为根的子树含有的节点数

     dp[rt] = dp[rt]:除去rt后,各个分块中节点数的最大值

     ans = min{ dp[rt] };

    删掉rt后,其孩子将作为某些子树的树根,选出一个最大值 max{ sum[son] }

   另外还有一棵子树是以整个树的祖先作为树根的子树,节点数为 sum[R]-sum[rt]

  所以 dp[rt]=max { sum[R]-sum[rt]  ,  max{sum[son]}  }

计算sum[]遍历一次树即可,在得到sum[]后再dfs一次整棵树即可。

隐式建树即可

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
#define N 16010
#define INF 0x3f3f3f3f

vector<int>a[N];
int n;
int sum[N]; //sum[rt]:以rt为根的子树有多少个节点
int dp[N];  //dp[rt]:除去rt后,各个分块中节点数的最大值
//ans = min{ dp[rt] };
int ans;
int res[N],resnum;
bool vis[N];

void travel(int rt)
{
   vis[rt]=true;
   sum[rt]=0;
   int size=a[rt].size();
   for(int i=0; i<size; i++)
   {
      int son=a[rt][i];
      if(!vis[son])
      {
         travel(son);
         sum[rt] += sum[son];
      }
   }
   ++sum[rt];
}

void dfs(int rt ,int R)
{
   vis[rt]=true;
   dp[rt]=0;
   int size=a[rt].size();
   for(int i=0; i<size; i++)
   {
      int son=a[rt][i];
      if(!vis[son])
      {
         dfs(son,R);
         dp[rt] = max(dp[rt] , sum[son]);
      }
   }
   dp[rt] = max(dp[rt] , sum[R]-sum[rt]);
   if(dp[rt] == ans)
   {
      res[resnum]=rt;
      resnum++;
   }
   if(dp[rt] < ans)
   {
      ans = dp[rt];
      resnum=1;
      res[0]=rt;
   }
}

void solve()
{
   memset(vis,false,sizeof(vis));
   for(int i=1; i<=n; i++)
      if(!vis[i])
      {
         travel(i);
      }
//   for(int i=1; i<=n; i++)
//      printf("%d:  %d\n",i,sum[i]);
   memset(vis,false,sizeof(vis));
   ans=INF;
   resnum=0;
   for(int i=1; i<=n; i++)
      if(!vis[i])
         dfs(i,i);

   sort(res,res+resnum);
   printf("%d %d\n",ans,resnum);
   for(int i=0; i<resnum; i++)
   {
      if(!i) printf("%d",res[i]);
      else   printf(" %d",res[i]);
   }
}

int main()
{
   scanf("%d",&n);
   for(int i=1; i<=n; i++)
      a[i].clear();
   for(int i=0; i<n-1; i++)
   {
      int u,v;
      scanf("%d%d",&u,&v);
      a[u].push_back(v);
      a[v].push_back(u);
   }
   solve();
   return 0;
}
原文地址:https://www.cnblogs.com/scau20110726/p/3015677.html