作用
虚树常常被使用在树形 (dp)中。
有些时候,我们需要计算的节点仅仅是一棵树中的某几个节点
这个时候如果对整棵树都进行一次计算开销太大了
所以我们需要把这些节点从原树中抽象出来
按照它们在原树中的关系重新建一棵树,这样的树就是虚树
构建方法
在构建之前,我们需要把所有需要加入的节点按照 (dfn) 序从小到大排好序
在加点时,我们要用栈维护一个最右链
在这个链左边的虚树都已经构建完成
我们设 (top) 为栈顶,设要加入的节点为 (now),设栈顶元素与 (now) 的 (LCA) 为 (lc)
在加入的时候,会有以下几种情况
(1)、(lc=sta[top])
此时我们直接把 (now) 接在最右链之后即可
(2)、(lc) 位于 (sta[top]) 和 (sta[top-1])之间
此时 (sta[tp]) 已经不在最右链上,将其在虚树上和 (lc) 连边后出栈
同时把 (lc) 和 (now) 依次入栈
(3)、(lc) 为 (sta[top-1])
和上面几乎一样,只是不把 (lc) 入栈
(4)、(lc) 的深度比 (sta[top-1]) 还小
我们把 (sta[top]) 和 (sta[top-1]) 连边后出栈,重复之前的操作
这样,我们直接在建出来的虚树上 (dp) 就可以了
设总点数为 (k),则时间复杂度为 (O(klogk))
代码实现
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define rg register
inline int read(){
rg int x=0,fh=1;
rg char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') fh=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*fh;
}
const int maxn=1e6+5;
int h[maxn],tot=1,h2[maxn],t2=1;
struct asd{
int to,nxt,val;
}b[maxn],b2[maxn];
void ad(rg int aa,rg int bb,rg int cc){
b[tot].to=bb;
b[tot].val=cc;
b[tot].nxt=h[aa];
h[aa]=tot++;
}
void ad2(rg int aa,rg int bb){
b2[t2].to=bb;
b2[t2].nxt=h2[aa];
h2[aa]=t2++;
}
int n,m,fa[maxn],dep[maxn],son[maxn],siz[maxn];
long long mindis[maxn];
void dfs1(rg int now,rg int lat){
fa[now]=lat;
dep[now]=dep[lat]+1;
siz[now]=1;
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==lat) continue;
mindis[u]=std::min(mindis[now],1LL*b[i].val);
dfs1(u,now);
siz[now]+=siz[u];
if(son[now]==0 || siz[u]>siz[son[now]]) son[now]=u;
}
}
int dfn[maxn],dfnc,tp[maxn],stk[maxn],cnt,sta[maxn],js;
void dfs2(rg int now,rg int top){
tp[now]=top;
dfn[now]=++dfnc;
if(son[now]) dfs2(son[now],top);
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==son[now] || u==fa[now]) continue;
dfs2(u,u);
}
}
bool cmp(rg int aa,rg int bb){
return dfn[aa]<dfn[bb];
}
int get_lca(rg int u,rg int v){
while(tp[u]!=tp[v]){
if(dep[tp[u]]<dep[tp[v]]) std::swap(u,v);
u=fa[tp[u]];
}
if(dep[u]<dep[v]) return u;
else return v;
}
void init(rg int now){
rg int lca=get_lca(now,sta[js]);
while(1){
if(dfn[lca]>=dfn[sta[js-1]]){
if(lca!=sta[js]){
ad2(sta[js],lca);
ad2(lca,sta[js]);
if(lca!=sta[js-1]){
sta[js]=lca;
} else {
js--;
}
}
break;
} else {
ad2(sta[js],sta[js-1]);
ad2(sta[js-1],sta[js]);
js--;
}
}
sta[++js]=now;
}
bool vis[maxn];
long long dfs(rg int now,rg int lat){
rg long long ans=0,cs=0;
for(rg int i=h2[now];i!=-1;i=b2[i].nxt){
rg int u=b2[i].to;
if(u==lat) continue;
ans+=dfs(u,now);
}
if(vis[now]){
cs=mindis[now];
} else {
cs=std::min(mindis[now],ans);
}
vis[now]=0;
h2[now]=-1;
return cs;
}
int main(){
memset(h,-1,sizeof(h));
memset(h2,-1,sizeof(h2));
memset(mindis,0x7f,sizeof(mindis));
n=read();
rg int aa,bb,cc;
for(rg int i=1;i<n;i++){
aa=read(),bb=read(),cc=read();
ad(aa,bb,cc);
ad(bb,aa,cc);
}
dfs1(1,0);
dfs2(1,1);
sta[0]=1;
m=read();
for(rg int i=1;i<=m;i++){
cnt=read();
t2=1;
for(rg int j=1;j<=cnt;j++){
aa=read();
stk[j]=aa;
vis[aa]=1;
}
std::sort(stk+1,stk+cnt+1,cmp);
sta[js=1]=stk[1];
for(rg int j=2;j<=cnt;j++){
init(stk[j]);
}
while(js>0){
ad2(sta[js],sta[js-1]);
ad2(sta[js-1],sta[js]);
js--;
}
printf("%lld
",dfs(1,0));
}
return 0;
}