UOJ#11. 【UTR #1】ydc的大树

LXXXII.UOJ#11. 【UTR #1】ydc的大树

很明显,如果我们令一个黑点\(x\)为树根,设它的“好朋友”集合为\(\mathbb{S}\),则路径\((x,\operatorname{LCA}\{\mathbb{S}\})\)中所有白节点均可以使\(x\)不开心。这个可以用树上差分来进行路径加。现在关键是求出\(\operatorname{LCA}\{\mathbb{S}\}\)

我们采取二次扫描与换根法。第一遍扫描,我们求出一个节点子树中所有黑点到它的距离的最大值,以及这些离它最远的黑点的\(\operatorname{LCA}\)。我们用一个std::pair<int,int>来储存这两个值,记作\(f_x\)

我们考虑怎么求出\(f_x\)出来——

设有一条边是\((x,y,z)\)。则\(f_x\)会从f[y].first+z最大的那个\(y\)转移过来;但是,如果存在两个不同的\(y\)都是最大值,显然它们的\(\operatorname{LCA}\)就是\(x\)本身。

这是求\(f_x\)的代码(如果不存在这个黑点,\(f_x\)即为\((-1,0)\))。

void dfs1(int x,int fa){
	if(bla[x])f[x]=make_pair(0,x);else f[x]=make_pair(-1,0);
	for(auto i:v[x]){
		if(i.first==fa)continue;
		dfs1(i.first,x);
		if(f[i.first].first!=-1)f[x].first=max(f[x].first,f[i.first].first+i.second);
	}
	if(f[x].first==-1)return;
	int cnt=0;
	for(auto i:v[x]){
		if(i.first==fa)continue;
		if(f[i.first].first==-1)continue;
		if(f[i.first].first+i.second==f[x].first)cnt++,f[x].second=f[i.first].second;
	}
	if(cnt>1)f[x].second=x;
}

既然要二次扫描,我们自然要设一个\(g_x\),表示\(x\)子树外所有节点到\(x\)的最大距离及它们的\(\operatorname{LCA}\)

\(g_x\)可以从这些东西转移过来:

  1. 兄弟们的\(f_y\)

  2. 父亲的\(g_y\)

  3. 父亲自身(假如父亲是黑点的话)

我们把前两个东西丢进vector中按照first从大到小排序。选取最大值(当然不能是\(f_x\)自身)转移即可。

当然,如果前两个东西中没有任何黑点,要考虑从父亲自身转移。

这部分的代码:

void dfs2(int x,int fa){
	vector<pair<int,int> >u;
	for(auto i:v[x])if(i.first!=fa&&f[i.first].first!=-1)u.push_back(make_pair(f[i.first].first+i.second,f[i.first].second));
	if(g[x]!=make_pair(-1,0))u.push_back(g[x]);
	sort(u.rbegin(),u.rend());
	for(auto i:v[x]){
		if(i.first==fa)continue;
		for(auto j:u){
			if(j.second==f[i.first].second)continue;
			if(g[i.first]==make_pair(0,0)){g[i.first]=j;continue;}
			if(g[i.first].first==j.first)g[i.first].second=x;
			break;
		}
		if(g[i.first]==make_pair(0,0))g[i.first]=(bla[x]?make_pair(i.second,x):make_pair(-1,0));
		else g[i.first].first+=i.second;
		dfs2(i.first,x);
	}
}

最终就是答案统计了。对于所有的黑点,如果f[x].first==g[x].first,显然\(\operatorname{LCA}\)\(x\)本身,可以忽略;否则,选取f[x]g[x]中较大的那个的second,进行树上差分即可。

复杂度\(O(n\log n)\)(瓶颈在于树上差分,求\(g_x\)部分的那个排序其实没有必要,但是如果这样写会更加清晰)。

代码:

#include<bits/stdc++.h>
using namespace std;
int n,m,anc[100100][20],dep[100100],sum[100100],mx,cnt;
vector<pair<int,int> >v[100100];
pair<int,int>f[100100],g[100100];//first:the maximum route length; second:the lca of all the 'good friends'
bool bla[100100];
void dfs1(int x,int fa){
	if(bla[x])f[x]=make_pair(0,x);else f[x]=make_pair(-1,0);
	for(auto i:v[x]){
		if(i.first==fa)continue;
		dfs1(i.first,x);
		if(f[i.first].first!=-1)f[x].first=max(f[x].first,f[i.first].first+i.second);
	}
	if(f[x].first==-1)return;
	int cnt=0;
	for(auto i:v[x]){
		if(i.first==fa)continue;
		if(f[i.first].first==-1)continue;
		if(f[i.first].first+i.second==f[x].first)cnt++,f[x].second=f[i.first].second;
	}
	if(cnt>1)f[x].second=x;
}
void dfs2(int x,int fa){
	vector<pair<int,int> >u;
	for(auto i:v[x])if(i.first!=fa&&f[i.first].first!=-1)u.push_back(make_pair(f[i.first].first+i.second,f[i.first].second));
	if(g[x]!=make_pair(-1,0))u.push_back(g[x]);
	sort(u.rbegin(),u.rend());
	for(auto i:v[x]){
		if(i.first==fa)continue;
		for(auto j:u){
			if(j.second==f[i.first].second)continue;
			if(g[i.first]==make_pair(0,0)){g[i.first]=j;continue;}
			if(g[i.first].first==j.first)g[i.first].second=x;
			break;
		}
		if(g[i.first]==make_pair(0,0))g[i.first]=(bla[x]?make_pair(i.second,x):make_pair(-1,0));
		else g[i.first].first+=i.second;
		dfs2(i.first,x);
	}
}
void dfs3(int x,int fa){
	anc[x][0]=fa,dep[x]=dep[fa]+1;
	for(auto i:v[x])if(i.first!=fa)dfs3(i.first,x);
}
void dfs4(int x,int fa){
	for(auto i:v[x])if(i.first!=fa)dfs4(i.first,x),sum[x]+=sum[i.first];
}
int LCA(int x,int y){
	if(dep[x]>dep[y])swap(x,y);
	for(int i=19;i>=0;i--)if(dep[x]<=dep[y]-(1<<i))y=anc[y][i];
	if(x==y)return x;
	for(int i=19;i>=0;i--)if(anc[x][i]!=anc[y][i])x=anc[x][i],y=anc[y][i];
	return anc[x][0];
}
int main(){
	scanf("%d%d",&n,&m);
	for(int x;m--;)scanf("%d",&x),bla[x]=true;
	for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),v[x].push_back(make_pair(y,z)),v[y].push_back(make_pair(x,z));
	g[1]=make_pair(-1,0);
	dfs1(1,0),dfs2(1,0),dfs3(1,0);
	for(int j=1;j<=19;j++)for(int i=1;i<=n;i++)anc[i][j]=anc[anc[i][j-1]][j-1];
	for(int i=1;i<=n;i++){
//		printf("%d:(%d,%d),(%d,%d)\n",i,f[i].first,f[i].second,g[i].first,g[i].second);
		if(!bla[i]||f[i].first==g[i].first)continue;
		int x,y=i;
		if(f[i].first>g[i].first)x=f[i].second;
		else x=g[i].second;
		int lca=LCA(x,y);
		sum[x]++,sum[y]++,sum[lca]--;
		if(anc[lca][0])sum[anc[lca][0]]--;
	}
	dfs4(1,0);
	for(int i=1;i<=n;i++)if(!bla[i])mx=max(mx,sum[i]);
	for(int i=1;i<=n;i++)if(!bla[i])cnt+=(sum[i]==mx);
	printf("%d %d",mx,cnt);
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14598499.html