[BZOJ3611][HEOI2014]大工程

BZOJ
Luogu

sol

很显然的虚树DP呀。
树上任意两点距离之和?其实只要考虑每一条边被计算了多少次即可,若这条边下方的关键点(也就是选出的那些点)数量为(i),那么这条边的计算次数就是(i*(k-i))
然后最大最小值,直接对每个点记子树中所有关键点到它的最长/最短距离即可。注意初值与这个点是不是关键点有关。
我一开始很傻很天真地以为最短距离一定是dfs序相邻的两个点然后就。。。
总体上还是挺好写的。

code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define ll long long
int gi()
{
    int x=0,w=1;char ch=getchar();
    while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
    if (ch=='-') w=0,ch=getchar();
    while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return w?x:-x;
}
const int N = 4e6+5;
int n,Q,to[N],nxt[N],val[N],head[N],cnt;
int fa[N],dep[N],sz[N],son[N],top[N],dfn[N],low[N];
int k,len,tp,s[N],q[N],mark[N],f[N],g[N];
int Max,Min;ll Sum;
void link(int u,int v,int w){to[++cnt]=v;nxt[cnt]=head[u];val[cnt]=w;head[u]=cnt;}
void dfs1(int u,int f)
{
	fa[u]=f;dep[u]=dep[f]+1;sz[u]=1;
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];if (v==f) continue;
		dfs1(v,u);
		sz[u]+=sz[v];if (sz[v]>sz[son[u]]) son[u]=v;
	}
}
void dfs2(int u,int up)
{
	top[u]=up;dfn[u]=++cnt;
	if (son[u]) dfs2(son[u],up);
	for (int e=head[u];e;e=nxt[e])
		if (to[e]!=fa[u]&&to[e]!=son[u])
			dfs2(to[e],to[e]);
	low[u]=cnt;
}
int getlca(int u,int v)
{
	while (top[u]^top[v])
	{
		if (dep[top[u]]<dep[top[v]]) swap(u,v);
		u=fa[top[u]];
	}
	return dep[u]<dep[v]?u:v;
}
bool cmp_dfn(int u,int v){return dfn[u]<dfn[v];}
void dp(int u)
{
	sz[u]=mark[u]?1:0;
	f[u]=mark[u]?0:-1e9;
	g[u]=mark[u]?0:1e9;
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];dp(v);
		sz[u]+=sz[v];Sum+=1ll*sz[v]*(k-sz[v])*val[e];
		Max=max(Max,f[u]+f[v]+val[e]);
		f[u]=max(f[u],f[v]+val[e]);
		Min=min(Min,g[u]+g[v]+val[e]);
		g[u]=min(g[u],g[v]+val[e]);
	}
}
int main()
{
	n=gi();
	for (int i=1;i<n;++i)
	{
		int u=gi(),v=gi();
		link(u,v,0);link(v,u,0);
	}
	dfs1(1,0);cnt=0;dfs2(1,1);
	Q=gi();memset(head,0,sizeof(head));
	while (Q--)
	{
		k=len=gi();tp=cnt=0;
		for (int i=1;i<=k;++i) mark[s[i]=gi()]=1;
		sort(s+1,s+k+1,cmp_dfn);
		for (int i=1;i<k;++i) s[++len]=getlca(s[i],s[i+1]);
		sort(s+1,s+len+1,cmp_dfn);len=unique(s+1,s+len+1)-s-1;
		for (int i=1;i<=len;++i)
		{
			while (tp&&low[q[tp]]<dfn[s[i]]) --tp;
			link(q[tp],s[i],dep[s[i]]-dep[q[tp]]);
			q[++tp]=s[i];
		}
		Sum=Max=0;Min=1e9;
		dp(s[1]);
		printf("%lld %d %d
",Sum,Min,Max);
		for (int i=1;i<=len;++i) mark[s[i]]=head[s[i]]=0;
	}
	return 0;
}
原文地址:https://www.cnblogs.com/zhoushuyu/p/8485398.html