【XSY1545】直径 虚树 DP

题目大意

  给你一棵 $n$ 个点的树，另外还有 $m$ 棵树，第 $i$ 棵树与原树中以 $r_i$ 为根的子树形态相同。这 $m$ 棵树之间还有 $m-1$ 条边相连，组成一棵大树。求这棵大树的直径长度。

  $n,m\leq 300000$

题解

  我们先用 DP 求出原树中以第 $i$ 个点为根的子树的直径。那么对于该子树中的任意一个点，以它为一个端点的最长路，其另一个端点一定是这条直径的某个端点。

  然后我们遍历第 $i$ 棵树与其他树之间的边，求出从这些边在第 $i$ 棵树中的端点出发、走到其他树中的最长路。再用虚树把这些端点和 $r_i$ 连在一起，用 DP 合并。

  zjt大爷：明明可以直接 BFS 两次，为什么要 DP？

  时间复杂度：$O((n+m)\log n)$

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
// 64-bit integer type for path lengths and the final answer.
typedef long long ll;
// Pair-of-ints alias (unused in the code below).
typedef std::pair<int,int> pii;
// Capacity of every per-vertex / per-edge array in this file.
const int maxn=1200000+10;
// Array-based ("forward star") adjacency list. Edge ids start at 1 and each
// source vertex has a singly linked chain of its edges; id 0 ends a chain.
struct list
{
	int v[maxn]; // v[e]: target vertex of edge e
	int t[maxn]; // t[e]: next edge in the same chain (0 = end of chain)
	int h[maxn]; // h[x]: most recently added edge leaving x (0 = none)
	int n;       // number of edges stored so far
	list():n(0)
	{
		memset(h,0,sizeof(h));
	}
	// Prepends the directed edge x -> y to x's chain.
	void add(int x,int y)
	{
		v[++n]=y;
		t[n]=h[x];
		h[x]=n;
	}
};
// Array-based edge list where every edge carries two (node, attribute) pairs:
// edge e goes from (x[e], x2[e]) to (y[e], y2[e]) and is chained under x[e].
// main() fills it with the edges between the m trees: x/y are tree indices and
// x2/y2 the vertices at the two ends of the edge.
struct list2
{
	int x[maxn];  // x[e]: source node
	int y[maxn];  // y[e]: target node
	int x2[maxn]; // x2[e]: attribute attached to the source end
	int y2[maxn]; // y2[e]: attribute attached to the target end
	int t[maxn];  // t[e]: next edge chained under the same source (0 = end)
	int h[maxn];  // h[a]: most recently added edge with source a (0 = none)
	int n;        // number of edges stored so far
	list2():n(0)
	{
		memset(h,0,sizeof(h));
	}
	// Appends edge (a, a2) -> (b, b2) at the head of a's chain.
	void add(int a,int a2,int b,int b2)
	{
		const int e=++n;
		x[e]=a;
		x2[e]=a2;
		y[e]=b;
		y2[e]=b2;
		t[e]=h[a];
		h[a]=e;
	}
};
// n: vertex count of the original tree; m: number of trees (both read in main).
int n;
int m;
namespace tree
{
	list l;
	struct p3
	{
		int d,x,y;
		p3(int a=0,int b=0,int c=0)
		{
			d=a;
			x=b;
			y=c;
		}
	};
	int operator <(p3 a,p3 b)
	{
		return a.d<b.d;
	}
	p3 a[maxn];
	int d[maxn];
	int f[maxn];
	int st[600010][21];
	int g[maxn];
	int g2[maxn];
	int ti;
	int bg[maxn];
	int ed[maxn];
	int w[maxn];
	int o[maxn];
	int dmin(int x,int y)
	{
		return d[x]<d[y]?x:y;
	}
	void init()
	{
		ti=0;
		memset(f,0,sizeof f);
	}
	void dfs(int x,int fa,int dep)
	{
		bg[x]=++ti;
		w[ti]=x;
		d[x]=dep;
		f[x]=fa;
		g[x]=1;
		g2[x]=x;
		a[x]=p3(1,x,x);
		int i;
		for(i=l.h[x];i;i=l.t[i])
			if(l.v[i]!=fa)
			{
				dfs(l.v[i],x,dep+1);
				a[x]=max(a[x],p3(g[x]+g[l.v[i]],g2[x],g2[l.v[i]]));
				a[x]=max(a[x],a[l.v[i]]);
				if(g[l.v[i]]+1>g[x])
				{
					g[x]=g[l.v[i]]+1;
					g2[x]=g2[l.v[i]];
				}
				ti++;
				w[ti]=x;
			}
		ed[x]=ti;
	}
	void buildst()
	{
		int i,j;
		for(i=1;i<=ti;i++)
			st[i][0]=w[i];
		for(j=1;j<=20;j++)
			for(i=1;i+(1<<j)-1<=ti;i++)
				st[i][j]=dmin(st[i][j-1],st[i+(1<<(j-1))][j-1]);
		o[1]=0;
		for(i=2;i<=ti;i++)
			o[i]=o[i/2]+1;
	}
	int getlca(int x,int y)
	{
		x=bg[x];
		y=bg[y];
		if(x>y)
			swap(x,y);
		int t=o[y-x+1];
		return dmin(st[x][t],st[y-(1<<t)+1][t]);
	}
	int getdist(int x,int y)
	{
		int lca=getlca(x,y);
		return d[x]+d[y]-2*d[lca]+1;
	}
}
// Edges between the m trees, stored in both directions.
list2 l2;
// r[i]: vertex of the original tree whose subtree tree i copies.
int r[maxn];
// Longest path found so far; printed by main.
ll ans;
// f[i]: filled by dp(i, ...) for tree i.
ll f[maxn];

// Scratch state reused by every dp() call to build a virtual tree.
int stack[maxn];  // monotone stack of virtual-tree vertices
int top;          // number of entries in stack
int c[maxn];      // key vertices of the virtual tree (1-based)
int par[maxn];    // par[v]: parent of v in the current virtual tree (0 = root)
ll g[maxn];       // g[v]: best chain hanging below v in the current virtual tree
int cnt;          // number of entries in c
int tag[maxn];    // tag[v] == i: g[v] was initialised while processing tree i
// sort() comparator: orders original-tree vertices by their first position in
// the Euler tour, i.e. by DFS preorder.
int dfscmp(int x,int y)
{
	const int posX=tree::bg[x];
	const int posY=tree::bg[y];
	return posX<posY;
}
int getmaxdist(int x,int y)
{
	return max(tree::getdist(y,tree::a[x].x),tree::getdist(y,tree::a[x].y));
}
void pushup(int x)
{
	int v=par[x];
	if(!v)
		return;
	ans=max(ans,g[v]+g[x]+tree::d[x]-tree::d[v]-1);
	g[v]=max(g[v],g[x]+tree::d[x]-tree::d[v]);
}
void dp(int x,int x2,int fa)
{
	f[x]=getmaxdist(r[x],x2);
	int i;
	for(i=l2.h[x];i;i=l2.t[i])
		if(l2.y[i]!=fa)
			dp(l2.y[i],l2.y2[i],x);
	cnt=0;
	c[++cnt]=x2;
	for(i=l2.h[x];i;i=l2.t[i])
		if(l2.y[i]!=fa)
		{
			c[++cnt]=l2.x2[i];
			ans=max(ans,getmaxdist(r[x],l2.x2[i])+f[l2.y[i]]);
			f[x]=max(f[x],f[l2.y[i]]+tree::getdist(x2,l2.x2[i]));
			if(tag[l2.x2[i]]==x)
			{
				ans=max(ans,g[l2.x2[i]]+f[l2.y[i]]);
				g[l2.x2[i]]=max(g[l2.x2[i]],f[l2.y[i]]+1);
			}
			else
			{
				tag[l2.x2[i]]=x;
				g[l2.x2[i]]=f[l2.y[i]]+1;
			}
		}
	if(tag[x2]!=x)
	{
		tag[x2]=x;
		g[x2]=1;
	}
	sort(c+1,c+cnt+1,dfscmp);
	cnt=unique(c+1,c+cnt+1)-c-1;
	top=0;
	int lca;
	for(i=1;i<=cnt;i++)
	{
		if(!top)
		{
			stack[++top]=c[i];
			par[c[i]]=0;
			continue;
		}
		int lca=tree::getlca(c[i],c[i-1]);
		while(top&&tree::d[stack[top]]>tree::d[lca])
		{
			if(tree::d[stack[top-1]]<tree::d[lca])
			{
				par[lca]=stack[top-1];
				par[stack[top]]=lca;
				g[lca]=1;
			}
			pushup(stack[top]);
			top--;
		}
		if(tree::d[stack[top]]<tree::d[lca])
		{
			par[lca]=stack[top];
			stack[++top]=lca;
		}
		par[c[i]]=stack[top];
		stack[++top]=c[i];
	}
	while(top)
	{
		pushup(stack[top]);
		top--;
	}
}
// Reads the original tree, the roots r[1..m] of the m trees and the m-1 edges
// between them from diameter.in, runs the DP and writes the answer to
// diameter.out.
int main()
{
	memset(tag,0,sizeof tag);
	freopen("diameter.in","r",stdin);
	freopen("diameter.out","w",stdout);
	tree::init();
	scanf("%d%d",&n,&m);
	int i,x,y,x2,y2;
	// Original tree: n-1 undirected edges.
	for(i=1;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		tree::l.add(x,y);
		tree::l.add(y,x);
	}
	tree::dfs(1,0,1);
	tree::buildst();
	for(i=1;i<=m;i++)
		scanf("%d",&r[i]);
	// m-1 edges: vertex x2 of tree x is joined to vertex y2 of tree y.
	for(i=1;i<m;i++)
	{
		scanf("%d%d%d%d",&x,&x2,&y,&y2);
		l2.add(x,x2,y,y2);
		l2.add(y,y2,x,x2);
	}
	ans=0;
	dp(1,r[1],0);
	// The format string was split across two lines (an unterminated literal);
	// restore the escaped newline.
	printf("%lld\n",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/ywwyww/p/8511204.html