【学习笔记/题解】虚树/[SDOI2011]消耗战

题目戳我

( ext{Solution:})

题目很显然可以设(dp[i])表示(i)的子树内的关键点都不和(i)联通的最小待机,有如下(dp)方程:

(vin son_u,vin key:dp[u]+=dis(u,v))

(vin son_u,v otin key:dep[u]+=min(dis(u,v),dp[v]))

但是暴力(dp)复杂度(O(nm).)观察(k)的大小发现,每一次都是有很多点不需要(dp)的。

于是我们可以考虑把原树根据所(dp)的关键点浓缩成一棵小一些的树,这便是虚树

虚树中包含的点只有各个关键点及其(LCA,LCA)(LCA)等。

考虑用单调栈维护链来把这个虚树构造出来。

首先处理出(LCA),这里我选择的倍增。然后处理出字典序。

对每一次的关键点,对其按照字典序排序后,把根加入栈中。

对于下一次新加入的点:

  • 先求出当前要入栈的点和栈顶的(LCA).并向后比对。大于它深度的弹出栈,对每一个点在构造虚树的时候把其父亲记录下来。

  • 若当前栈顶元素等于(LCA)则终止弹栈。

  • 若当前栈顶元素的下一个元素深度小于(LCA),则将当前栈顶元素的父亲认为(LCA)并弹出此栈顶,终止弹栈。

  • 终止后,若(LCA)没有进过栈,则将其入栈,并将其父亲认为其入栈前的栈顶元素。

  • 最后,当前元素入栈,认父亲为当前栈顶元素。

  • 用一个数组将所有进过栈的元素储存下来,这便是虚树中的所有元素。又因为之前记录过父亲,这虚树就可以被还原出来。

  • 对数组中的元素按照字典序排序后用非递归的方式进行(dp)即可。

非虚树代码:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+10;
int dp[MAXN],n,m,h[MAXN],tot,head[MAXN];
struct E{int nxt,to,dis;}e[MAXN];
int vis[MAXN];
inline void add(int x,int y,int w){
	e[++tot]=(E){head[x],y,w};
	head[x]=tot;
}
void dfs(int x,int fa){
	for(int i=head[x];i;i=e[i].nxt){
		int j=e[i].to;
		if(j==fa)continue;
		dfs(j,x);
		if(vis[j])dp[x]+=e[i].dis;
		else dp[x]+=min(dp[j],e[i].dis);
		dp[j]=0; 
	}
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<n;++i){
		int x,y,z;
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z);add(y,x,z);
	}
	scanf("%d",&m);
	while(m--){
		int x;
		scanf("%d",&x);
		for(int i=1;i<=x;++i){
			scanf("%d",h+i);
			vis[h[i]]=1;
		}
		dfs(1,0);printf("%d
",dp[1]);dp[1]=0;
		for(int i=1;i<=x;++i)vis[h[i]]=0;
	}
	return 0;
}

虚树代码(附注释):

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+10;
const int inf=(1<<30);
char buf[1<<21],*p1=buf,*p2=buf;
inline int read(){
	#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
	char ch=gc();int s=0;
	while(!isdigit(ch))ch=gc();
	while(isdigit(ch))s=s*10-48+ch,ch=gc();
	return s;
}
int tot,id,dfn[MAXN],n,m,st[MAXN],dp[MAXN],M[MAXN][22],pa[MAXN];
int f[MAXN][22],head[MAXN],h[MAXN],top,dep[MAXN],vis[MAXN],L;
struct E{int nxt,to,dis;}e[MAXN];
inline void add(int x,int y,int w){
	e[++tot]=(E){head[x],y,w}; 
	head[x]=tot;
}
int TM,yy[MAXN],val[MAXN];
long long ans[MAXN];
void dfs(int x,int fa){
	dfn[x]=++id,dep[x]=dep[fa]+1,f[x][0]=fa;
	for(int i=1;i<=20;++i){
		f[x][i]=f[f[x][i-1]][i-1];
		M[x][i]=min(M[x][i-1],M[f[x][i-1]][i-1]);
	}
	for(int i=head[x];i;i=e[i].nxt){
		int j=e[i].to;
		if(j==fa)continue;
		M[j][0]=e[i].dis;
		dfs(j,x);
	}
}
int lca(int x,int y){
	if(dep[x]>dep[y])swap(x,y);
	for(int i=21;i>=0;--i)if(dep[x]<=dep[y]-(1<<i))y=f[y][i];
	if(x==y)return x;
	for(int i=21;i>=0;--i)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
	return f[x][0];//LCA 
}
bool cmp(int x,int y){return dfn[x]<dfn[y];}
int dis(int x,int y){
	int ans=inf;
	if(dep[x]<dep[y])swap(x,y);
	for(int i=20;i>=0;--i){
		if(dep[f[x][i]]>=dep[y])
			ans=min(ans,M[x][i]),x=f[x][i];
		if(x==y)return ans;
	}
	for(int i=21;i>=0;--i)if(f[x][i]!=f[y][i])ans=min(ans,min(M[x][i],M[y][i])),x=f[x][i],y=f[y][i]; 
	return ans;//求两点间距离最小值 
}
void Build(){
	sort(h+1,h+L+1,cmp);
	int tmp=L;top=0;
	for(int i=1,l;i<=tmp;++i){
		int u=h[i];
		if(!top){
			pa[u]=0;
			st[++top]=u;
			continue;
		}
		int w=lca(st[top],u);
		//当前点和栈顶的LCA 
		while(dep[st[top]]>dep[w]){
			if(dep[st[top-1]]<dep[w])pa[st[top]]=w;//如果下一个点的儿子是w,那么当前点的父亲就是w(维护链,按深度判断) 
			top--;//弹出 
		}
		if(w!=st[top]){
			h[++L]=w;
			pa[w]=st[top];
			st[++top]=w;
			//如果w未进过栈,则其父亲是当前栈顶,h将其记录下,进栈 
		}
		pa[u]=w,st[++top]=u;//u进栈,其父亲是上一个栈顶(w此时必然是栈顶) 
	}
	sort(h+1,h+L+1,cmp);//现在h里面存了所有虚树上的点,按深度排序 
}
void DP(){
	for(int i=1;i<=L;++i)ans[h[i]]=0;//初始化 
	for(int i=L;i>=2;--i){//从深度最大的开始dp,第一个点一定是根 不dp 
		int u=h[i];
		if(vis[u])ans[pa[u]]+=1ll*val[u];//基本dp不解释 
		else ans[pa[u]]+=min(1ll*val[u],ans[u]);
	}
}
long long solve(){
	for(int i=2;i<=L;++i)val[h[i]]=dis(h[i],pa[h[i]]);//预处理dis 倍增处理掉 
	DP();return ans[1];
}
int main(){
	n=read();
	for(int i=1;i<n;++i){
		int x=read(),y=read(),z=read();
		add(x,y,z);add(y,x,z);
	}
	m=read();dfs(1,0);//处理字典序 深度 边权最小值 LCA等信息 
	while(m--){
		L=read()+1;h[1]=1;TM=L-1;//强制第一个点是1 
		for(int j=2;j<=L;++j)h[j]=read(),vis[h[j]]=1,yy[j-1]=h[j];
		Build();printf("%lld
",solve());
		for(int i=1;i<=TM;++i)vis[yy[i]]=0;
	}
	return 0;
} 
原文地址:https://www.cnblogs.com/h-lka/p/13767060.html