虚树学习笔记

作用

虚树常常被使用在树形 (dp)中。

有些时候,我们需要计算的节点仅仅是一棵树中的某几个节点

这个时候如果对整棵树都进行一次计算开销太大了

所以我们需要把这些节点从原树中抽象出来

按照它们在原树中的关系重新建一棵树,这样的树就是虚树

构建方法

在构建之前,我们需要把所有需要加入的节点按照 (dfn) 序从小到大排好序

在加点时,我们要用栈维护一个最右链

在这个链左边的虚树都已经构建完成

我们设 (top) 为栈顶,设要加入的节点为 (now),设栈顶元素与 (now)(LCA)(lc)

在加入的时候,会有以下几种情况

(1)(lc=sta[top])

此时我们直接把 (now) 接在最右链之后即可
(2)(lc) 位于 (sta[top])(sta[top-1])之间

此时 (sta[tp]) 已经不在最右链上,将其在虚树上和 (lc) 连边后出栈

同时把 (lc)(now) 依次入栈

(3)(lc)(sta[top-1])

和上面几乎一样,只是不把 (lc) 入栈

(4)(lc) 的深度比 (sta[top-1]) 还小

我们把 (sta[top])(sta[top-1]) 连边后出栈,重复之前的操作

这样,我们直接在建出来的虚树上 (dp) 就可以了

设总点数为 (k),则时间复杂度为 (O(klogk))

代码实现

P2495 [SDOI2011]消耗战为例

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=1e6+5;
int h[maxn],tot=1,h2[maxn],t2=1;
struct asd{
	int to,nxt,val;
}b[maxn],b2[maxn];
void ad(rg int aa,rg int bb,rg int cc){
	b[tot].to=bb;
	b[tot].val=cc;
	b[tot].nxt=h[aa];
	h[aa]=tot++;
}
void ad2(rg int aa,rg int bb){
	b2[t2].to=bb;
	b2[t2].nxt=h2[aa];
	h2[aa]=t2++;
}
int n,m,fa[maxn],dep[maxn],son[maxn],siz[maxn];
long long mindis[maxn];
void dfs1(rg int now,rg int lat){
	fa[now]=lat;
	dep[now]=dep[lat]+1;
	siz[now]=1;
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==lat) continue;
		mindis[u]=std::min(mindis[now],1LL*b[i].val);
		dfs1(u,now);
		siz[now]+=siz[u];
		if(son[now]==0 || siz[u]>siz[son[now]]) son[now]=u;
	}
}
int dfn[maxn],dfnc,tp[maxn],stk[maxn],cnt,sta[maxn],js;
void dfs2(rg int now,rg int top){
	tp[now]=top;
	dfn[now]=++dfnc;
	if(son[now]) dfs2(son[now],top);
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==son[now] || u==fa[now]) continue;
		dfs2(u,u);
	}
}
bool cmp(rg int aa,rg int bb){
	return dfn[aa]<dfn[bb];
}
int get_lca(rg int u,rg int v){
	while(tp[u]!=tp[v]){
		if(dep[tp[u]]<dep[tp[v]]) std::swap(u,v);
		u=fa[tp[u]];
	}
	if(dep[u]<dep[v]) return u;
	else return v;
}
void init(rg int now){
	rg int lca=get_lca(now,sta[js]);
	while(1){
		if(dfn[lca]>=dfn[sta[js-1]]){
			if(lca!=sta[js]){
				ad2(sta[js],lca);
				ad2(lca,sta[js]);
				if(lca!=sta[js-1]){
					sta[js]=lca;
				} else {
					js--;
				}
			}
			break;
		} else {
			ad2(sta[js],sta[js-1]);
			ad2(sta[js-1],sta[js]);
			js--;
		}
	}
	sta[++js]=now;
}
bool vis[maxn];
long long dfs(rg int now,rg int lat){
	rg long long ans=0,cs=0;
	for(rg int i=h2[now];i!=-1;i=b2[i].nxt){
		rg int u=b2[i].to;
		if(u==lat) continue;
		ans+=dfs(u,now);
	}
	if(vis[now]){
		cs=mindis[now];
	} else {
		cs=std::min(mindis[now],ans);
	}
	vis[now]=0;
	h2[now]=-1;
	return cs;
}
int main(){
	memset(h,-1,sizeof(h));
	memset(h2,-1,sizeof(h2));
	memset(mindis,0x7f,sizeof(mindis));
	n=read();
	rg int aa,bb,cc;
	for(rg int i=1;i<n;i++){
		aa=read(),bb=read(),cc=read();
		ad(aa,bb,cc);
		ad(bb,aa,cc);
	}
	dfs1(1,0);
	dfs2(1,1);
	sta[0]=1;
	m=read();
	for(rg int i=1;i<=m;i++){
		cnt=read();
		t2=1;
		for(rg int j=1;j<=cnt;j++){
			aa=read();
			stk[j]=aa;
			vis[aa]=1;
		}
		std::sort(stk+1,stk+cnt+1,cmp);
		sta[js=1]=stk[1];
		for(rg int j=2;j<=cnt;j++){
			init(stk[j]);
		}
		while(js>0){
			ad2(sta[js],sta[js-1]);
			ad2(sta[js-1],sta[js]);
			js--;
		}
		printf("%lld
",dfs(1,0));
	}
	return 0;
}
原文地址:https://www.cnblogs.com/liuchanglc/p/14184719.html