BZOJ3068 : 小白树

枚举每条树边,将其断开,那么两侧肯定取带权重心最优。

考虑如何求出每个子树的重心,枚举其所有儿子,通过重量关系就可以判断出重心位于哪棵子树。

然后将那棵子树的重心暴力往上爬即可,因为每个点作为重心肯定是一段连续的链,所以复杂度为$O(n)$。

然后就是如何求出砍掉每棵子树之后剩下的部分的重心。

设当前点到根的路径为关键路径,那么可以通过二分求出重心在关键路径上哪个点的子树里。

对于那棵子树,重心要么是它本身,要么在它最重的子树里,要么在次重的子树里。

在线段树上按dfs序维护区间内子树重量的最大值,即可用线段树完成重心的查询,时间复杂度$O(log n)$。

总时间复杂度$O(nlog n)$。

#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=500010,BUF=12000000;
char Buf[BUF],*buf=Buf;
int n,i,x,y,w[N],g[N],v[N<<1],nxt[N<<1],ed;
int f[N],d[N],size[N],son[N],top[N],st[N],en[N],id[N],dfn,q[N],cq;
int fir[N],sec[N],center[N],vip[N];
ll sum[N],sw[N],val[1050000],sd[N],su[N],ans=1LL<<60;
inline void read(int&a){for(a=0;*buf<48;buf++);while(*buf>47)a=a*10+*buf++-48;}
inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
void dfs(int x){
  size[x]=1;
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){
    f[v[i]]=x,d[v[i]]=d[x]+1;
    dfs(v[i]),size[x]+=size[v[i]];
    if(size[v[i]]>size[son[x]])son[x]=v[i];
  }
}
void dfs2(int x,int y){
  id[st[x]=++dfn]=x;top[x]=y;
  if(son[x])dfs2(son[x],y);
  for(int i=g[x];i;i=nxt[i])if(v[i]!=son[x]&&v[i]!=f[x])dfs2(v[i],v[i]);
  en[x]=dfn;
}
inline int dis(int x,int y){
  int t=d[x]+d[y];
  for(;top[x]!=top[y];x=f[top[x]])if(d[top[x]]<d[top[y]])swap(x,y);
  if(d[x]>d[y])swap(x,y);
  return t-2*d[x];
}
inline void cal(int x){
  int i=center[fir[x]];
  ll t=sum[fir[x]]+sd[x]-sd[fir[x]]-sw[fir[x]]+(sw[x]-sw[fir[x]])*(d[i]-d[x]);
  while(2*sw[i]<sw[x])t+=2*sw[i]-sw[x],i=f[i];
  center[x]=i;
  sum[x]=t;
}
void dfs3(int x){
  sw[x]=w[x];
  for(int i=g[x];i;i=nxt[i]){
    int y=v[i];
    if(y==f[x])continue;
    dfs3(y);
    sw[x]+=sw[y];
    if(sw[y]>sw[fir[x]])sec[x]=fir[x],fir[x]=y;
    else if(sw[y]>sw[sec[x]])sec[x]=y;
    sd[x]+=sd[y]+sw[y];
  }
  if(2*sw[fir[x]]<=sw[x]){
    center[x]=x;
    sum[x]=sd[x];
    return;
  }
  cal(x);
}
void dfs4(int x){
  if(f[x]){
    int y=f[x];
    su[x]=su[y]+sd[y]-sd[x]-2*sw[x]+sw[1];
  }
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x])dfs4(v[i]);
}
inline int lower(ll x){
  int l=2,r=cq,mid,t=1;
  while(l<=r)if(2*(sw[q[mid=(l+r)>>1]]-x)>=sw[1]-x)l=(t=mid)+1;else r=mid-1;
  return q[t];
}
void build(int x,int a,int b){
  if(a==b){val[x]=sw[id[a]]*2;return;}
  int mid=(a+b)>>1;
  build(x<<1,a,mid),build(x<<1|1,mid+1,b);
  val[x]=max(val[x<<1],val[x<<1|1]);
}
int ask(int x,int a,int b,int c,int d,ll p){
  if(val[x]<p)return 0;
  if(a==b)return a;
  int mid=(a+b)>>1,t=0;
  if(d>mid)t=ask(x<<1|1,mid+1,b,c,d,p);
  if(t)return t;
  if(c<=mid)t=ask(x<<1,a,mid,c,d,p);
  return t;
}
void dfs5(int x){
  if(f[x])vip[x]=lower(sw[x]);
  q[++cq]=x;
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x])dfs5(v[i]);
  cq--;
}
inline void solve(int x){
  int t=vip[x],y,z=0;
  if(st[fir[t]]<=st[x]&&en[x]<=en[fir[t]])y=sec[t];else y=fir[t];
  if(y)z=ask(1,1,n,st[y],en[y],sw[1]-sw[x]);
  if(!z)z=t;else z=id[z];
  ans=min(ans,sum[x]+sd[z]+su[z]-sd[x]-sw[x]*dis(x,z));
}
int main(){
  fread(Buf,1,BUF,stdin);read(n);
  for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x);
  for(i=1;i<=n;i++)read(w[i]);
  dfs(1);dfs2(1,1);
  dfs3(1);dfs4(1);
  build(1,1,n);dfs5(1);
  for(i=2;i<=n;i++)solve(i);
  return printf("%lld",ans),0;
}

  

原文地址:https://www.cnblogs.com/clrs97/p/5842291.html