虚树简介
我们发现有些树形(dp)题目,会给定几个特殊点,围绕这几个特殊点求贡献
那么我们会发现树中其实存在大量传递点,这些点在树形(dp)的过程中的作用仅仅是把下面某个子树的信息传上去或根本没有信息
那么我们就可以建立出一颗新基于原树的虚树,让虚树仅包含特殊点和这些点的(lca),((lca)可能需要整合多个特殊点信息,不属于传递点),省去传递点的复杂度
虚树构建
我们构建出的虚树应该是原树的精简版,不能破坏原树的形态
我们先将要构建成虚树的点根据(dfs)序排序
然后考虑增量构造法
先用一个栈存储将要连边的点
我们有两个情况分类:
一:新加入的点(x)在(st[top])的子树中
也就是(lca(x,st[top])==st[top]),注意按(dfs)序排列(x)必不可能是(lca)
那么我们直接将(x)压入栈顶,等待弹栈
二:新加入的点(x)不在(st[top])的子树中
由于按(dfs)序排列,那么后面的点不会再有(st[top])的子树中的点了
那么一直弹栈,每次弹栈之前让(st[top-1])向(st[top])连一条边,直到栈顶的深度小于(lca)为止
我们再把(lca)和(x)存入栈中
最后吧栈弹空
可以得出虚树的总点数<=询问总点数*2
那么我们在虚树上树形(dp)而不影响原答案
例题
(n)个点的有根树,(m)次询问,每次给定(k)个点,求至少要切断多少点才能让(k)个点之间两两不连通
(n<=2.5*10^5,m<=5*10^5,sum{k_i}<=5*10^5)
先特判一步,如果父子都是特殊点无解
考虑一个(O(nm))的大力(dp)
设(ret[x])是以(x)为根的子树的答案和,(siz[x])是(x)的子树中存在多少个特殊点,(vis[x])表示(x)是不是特殊点
如果(vis[x]==1),那么这个特殊点下的每个子树都要再切断一个点,返回(ret[x]+siz[x])
如果(siz[x]==0),返回(ret[x])
如果(siz[x]==1),那么(vis[x]=1),返回(ret[x]),表示下面这个特殊点目前没有必要切断,但是需要传递上去
如果(siz[x]>=2),那么就把该点切断,返回(ret[x]+1)
然后我们发现如果建立出虚树,那么复杂度就是(O(sum{k_i}))
#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define mid ((l+r)>>1)
inline int read()
{
int x=0;char ch,f=1;
for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
if(ch=='-') f=0,ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return f?x:-x;
}
const int N=4e5+10,inf=1e9+7;
int n,m;
vector<int> eg[N];
int head[N],cnt;
struct point
{
int nxt,to;
point(){}
point(const int &nxt,const int &to):nxt(nxt),to(to){}
}a[N<<1];
inline void link(int x,int y)
{
a[++cnt]=(point){head[x],y};head[x]=cnt;
a[++cnt]=(point){head[y],x};head[y]=cnt;
}
int lg[N],dep[N],f[N][21],dfn[N],idx;
int q[N],num;
int st[N],top;
bool vis[N];
inline void dfs1(int now,int fa)
{
f[now][0]=fa;dfn[now]=++idx;
dep[now]=dep[fa]+1;
for(int i=1;i<=20;++i) f[now][i]=f[f[now][i-1]][i-1];
for(int i=head[now];i;i=a[i].nxt)
{
int t=a[i].to;
if(t==fa) continue;
dfs1(t,now);
}
}
inline int getlca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]^dep[y]) x=f[x][lg[dep[x]-dep[y]]-1];
if(x==y) return x;
for(int i=20;~i;--i)
if(f[x][i]^f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
inline bool cmp(const int &a,const int &b)
{
return dfn[a]<dfn[b];
}
inline void insert(int x)
{
if(top==1) {st[++top]=x;return;}
int lca=getlca(x,st[top]);
if(lca==st[top])
{
st[++top]=x;
return;
}
while(top>1&&dfn[st[top-1]]>=dfn[lca]) eg[st[top-1]].push_back(st[top]),--top;
if(lca!=st[top]) eg[lca].push_back(st[top]),st[top]=lca;
st[++top]=x;
}
inline int dfs(int now)
{
int sum=eg[now].size(),ret=0,siz=0;
for(int i=0;i<sum;++i)
{
int t=eg[now][i];
ret+=dfs(t);
if(vis[t]) ++siz;
}
if(vis[now]) return siz+ret;
if(!siz) return ret;
if(siz>=2) return ret+1;
vis[now]=1;return ret;
}
inline void sol(int now)
{
int sum=eg[now].size();
vis[now]=0;
for(int i=0;i<sum;++i)
{
int t=eg[now][i];
sol(t);
}
eg[now].clear();
}
inline void main()
{
n=read();
for(int x,y,i=1;i<n;++i)
{
x=read(),y=read();
link(x,y);
}
for(int i=1;i<=n;++i) lg[i]=lg[i>>1]+1;
lg[0]=1;
dfs1(1,0);
m=read();
for(int i=1;i<=m;++i)
{
num=read();
for(int j=1;j<=num;++j) q[j]=read(),vis[q[j]]=1;
sort(q+1,q+num+1,cmp);
bool flag=0;
for(int j=1;j<=num;++j)
{
if(vis[f[q[j]][0]])
{
puts("-1");
flag=1;
break;
}
}
if(flag)
{
for(int j=1;j<=num;++j) vis[q[j]]=0;
continue;
}
st[top=1]=1;
for(int j=1;j<=num;++j)
{
if(q[j]^1) insert(q[j]);
}
while(top>0) eg[st[top-1]].push_back(st[top]),--top;
printf("%lld
",dfs(1));
sol(1);
}
}
}
signed main()
{
red::main();
return 0;
}