bzoj 3572 [Hnoi2014]世界树——虚树

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3572

关于虚树:https://www.cnblogs.com/zzqsblog/p/5560645.html

构造方法:

  先把关键点按 dfs 序排序,然后依次插入树中;

  插入当前点 cr 的时候,求 lca = get_lca( cr , sta[top] ) ;如果 dep[ sta[top] ] >= dep[lca] ,就一直弹栈;

  弹栈结束后,看看现在的 sta[ top ] 是不是就是 lca 了,如果不是,就 sta[ ++ top ] = lca ;同时 fa[ sta[top+1] ] = lca , fa[ lca ] = sta[ top ] ;

  把 cr 也加入栈中,即 sta[++top] = cr , fa[ cr ] = lca 。

sta[ 1 ] 就是虚树的根。

关于这道题:http://hzwer.com/6804.html

建好虚树,先换根 dp 得出虚树上的每个点应该被哪个点控制。换根的时候不用去掉该子树的贡献,因为不会有影响。

然后枚举虚树上的每条边 ( cr , fa ),用倍增在边上找到最浅的应该 “被控制 cr 的点控制” 的点 v ,然后 siz[v] - siz[cr] 和 siz[tv] - siz[v] 分别贡献即可,其中 tv 是 fa 在 v 方向的直接孩子。

关于不是虚树的点也不在虚树边上的那些点,自己的方法是在换根 dp 的时候处理;那个时候枚举孩子 v 的时候通过找 tv ,可以知道每个虚树上的点 cr 的不在虚树上的孩子们的 siz 和,直接贡献给控制 cr 的那个点即可。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define mkp make_pair
#define fir first
#define sec second
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=3e5+5,K=18;
int n,hd[N],xnt,to[N<<1],nxt[N<<1];
int dfn[N],dep[N],pre[N][K+5],bin[K+5],lg[N],siz[N],sm[N];
pair<int,int> dp[N];
int m,ans[N],tt,sta[N],tot,fa[N],h2[N],xt2,t2[N<<1],nt2[N<<1];
struct Node{int v,id;}q[N];

bool cmp(Node x,Node y){return dfn[x.v]<dfn[y.v];}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void ad2(int x,int y){t2[++xt2]=y;nt2[xt2]=h2[x];h2[x]=xt2;}
void ini_dfs(int cr,int fa)
{
  dfn[cr]=++tot; dep[cr]=dep[fa]+1;
  pre[cr][0]=fa; siz[cr]=1;
  for(int t=1,u;(u=pre[pre[cr][t-1]][t-1]);t++)
    pre[cr][t]=u;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    ini_dfs(v,cr); siz[cr]+=siz[v];
      }
}
int get_lca(int x,int y)
{
  if(dep[x]<dep[y])swap(x,y);
  for(int t=lg[dep[x]-dep[y]];t>=0;t--)
    if(dep[pre[x][t]]>=dep[y])
      x=pre[x][t];
  if(x==y)return x;
  for(int t=lg[dep[x]];t>=0;t--)
    if(pre[x][t]!=pre[y][t])
      x=pre[x][t],y=pre[y][t];
  return pre[x][0];
}
void build()
{
  sort(q+1,q+m+1,cmp); tt=m; tot=0;
  for(int i=1;i<=m;i++)
    {
      int cr=q[i].v;
      if(!tot){sta[++tot]=cr;dp[cr]=mkp(0,cr);continue;}
      int lca=get_lca(cr,sta[tot]);
      while(dep[sta[tot]]>dep[lca])tot--;
      if(sta[tot]!=lca)
    {
      q[++tt].v=lca; fa[sta[tot+1]]=lca;
      fa[lca]=sta[tot]; sta[++tot]=lca;
      dp[lca]=mkp(N,0);
    }
      fa[cr]=lca; sta[++tot]=cr; dp[cr]=mkp(0,cr);
    }
  for(int i=1;i<=tt;i++)if(fa[q[i].v])ad2(fa[q[i].v],q[i].v);
}
void dfs(int cr,int fa)
{
  for(int i=h2[cr],v;i;i=nt2[i])
    if((v=t2[i])!=fa)
      {
    dfs(v,cr);
    int tp1=dp[v].fir+dep[v]-dep[cr],tp2=dp[cr].fir;
    if(tp1<tp2||(tp1==tp2&&dp[v].sec<dp[cr].sec))
      dp[cr].fir=tp1,dp[cr].sec=dp[v].sec;
      }
}
int fnd2(int cr,int fa)
{
  int d=dep[cr]-dep[fa]-1;
  while(d)
    {
      int lbt=(d&-d);
      cr=pre[cr][lg[lbt]]; d-=lbt;
    }
  return cr;
}
void dfsx(int cr,int fa)
{
  int tp1=dp[fa].fir+dep[cr]-dep[fa],tp2=dp[cr].fir;
  if(fa&&(tp1<tp2||(tp1==tp2&&dp[fa].sec<dp[cr].sec)))
    dp[cr].fir=tp1,dp[cr].sec=dp[fa].sec;
  int s=siz[cr];
  for(int i=h2[cr],v;i;i=nt2[i])
    if((v=t2[i])!=fa)
      {
    int tv=fnd2(v,cr);
    s-=siz[tv]; dfsx(v,cr);
      }
  sm[dp[cr].sec]+=s-1;//-1 for zj
}
int fnd(int cr,int fa)
{
  bool fg=(dp[cr].sec<dp[fa].sec);
  int x=dp[cr].fir,y=dp[fa].fir,d1=dep[cr],d2=dep[fa];
  for(int t=lg[dep[cr]-dep[fa]];t>=0;t--)
    {
      int d=dep[pre[cr][t]];
      int u=d1-d+x,v=d-d2+y;
      if(u<v||(u==v&&fg))cr=pre[cr][t];
    }
  return cr;
}
bool In(int cr,int fa){return dfn[cr]>=dfn[fa]&&dfn[cr]<dfn[fa]+siz[fa];}
void solve()
{
  for(int i=1;i<=m;i++)sm[q[i].v]=0;
  int rt=sta[1]; dfs(rt,0); dfsx(rt,0);
  for(int i=1;i<=tt;i++)sm[dp[q[i].v].sec]++;
  for(int i=1;i<=tt;i++)
    {
      int cr=q[i].v,f=fa[cr];
      if(!f) {sm[dp[cr].sec]+=n-siz[cr];continue;}
      int v=fnd(cr,f),tv=fnd2(cr,f);
      sm[dp[cr].sec]+=siz[v==f?tv:v]-siz[cr];//
      sm[dp[f].sec]+=siz[tv]-siz[v==f?tv:v];//tv//
    }
  for(int i=1;i<=m;i++)ans[q[i].id]=sm[q[i].v];
  for(int i=1;i<=m;i++)printf("%d ",ans[i]); puts("");
}
int main()
{
  n=rdn();
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  bin[0]=1;for(int i=1;i<=K;i++)bin[i]=bin[i-1]<<1;
  for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
  ini_dfs(1,0); int Q=rdn();
  while(Q--)
    {
      for(int i=1;i<=tt;i++)h2[q[i].v]=0; xt2=0;
      for(int i=1;i<=tt;i++)fa[q[i].v]=0;///
      m=rdn(); for(int i=1;i<=m;i++)q[i].v=rdn(),q[i].id=i;
      build(); solve();
    }
  return 0;
}
原文地址:https://www.cnblogs.com/Narh/p/10367190.html