[ZJOI2019]语言

https://www.luogu.org/problemnew/show/P5327

题解

首先考虑每个点能够到达的集合,它在树上一定是一个联通块。

先考虑枚举每个点,然后算它能到达多少点。

然后我们把所有跨越这个点的链全都拿出来,那么这个联通块的边数=这个点能够访问的点数=这些链的链并。

考虑传统的求链并的方法,就是按照(dfs)序排序,然后答案就是每个点的deep减去相邻两个点的(lca)(deep)再减去所有点的(lca)(deep)

这就是大致思路,再考虑如何维护答案。

用类似树上差分的思想,对于每一条链,在(u)(v)处分别打上(u:+1)(v:+1)的标记,再在lca和(fa[lca])处打上(-1)的标记。

然后用线段树合并维护所有标记的贡献即可。

代码

#include<bits/stdc++.h>
#define N 100009
#define inf 2e9
#define ls tr[cnt].l
#define rs tr[cnt].r
using namespace std;
typedef long long ll;
int lo[N<<1],head[N],tot,tott,deep[N],mp[N],dfn[N],p[20][N<<1],now,fa[N],T[N],n,m,_tag[N];
ll ans;
vector<int>vec[N];
inline ll rd(){
	ll x=0;char c=getchar();bool f=0;
	while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
	while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return f?-x:x;
}
struct edge{int n,to;}e[N<<1];
struct seg{
	int l,r,mi,ma,cnt;
	ll val;
}tr[N*80];
inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;}
inline int maxx(int u,int v){return deep[u]>deep[v]?v:u;}
inline int _min(int u,int v){if(!u||!v)return u|v;return dfn[u]<dfn[v]?u:v;}
inline int _max(int u,int v){if(!u||!v)return u|v;return dfn[u]>dfn[v]?u:v;}
inline int getlca(int u,int v){
	if(!u||!v)return 0;
	u=mp[u];v=mp[v];
	if(u>v)swap(u,v);
	int loo=lo[v-u+1];
	return maxx(p[loo][u],p[loo][v-(1<<loo)+1]);
}
inline void pushup(int cnt){
	tr[cnt].val=tr[ls].val+tr[rs].val-deep[getlca(tr[ls].ma,tr[rs].mi)];
	tr[cnt].ma=_max(tr[ls].ma,tr[rs].ma);
	tr[cnt].mi=_min(tr[ls].mi,tr[rs].mi);
}
void dfs(int u){
    dfn[u]=++dfn[0];_tag[dfn[0]]=u;
	p[0][++now]=u;mp[u]=now;
	for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]){
		int v=e[i].to;deep[v]=deep[u]+1;fa[v]=u;
		dfs(v);p[0][++now]=u;
	}
}
void upd(int &cnt,int l,int r,int x,int tag){
	if(!cnt)cnt=++tott;
	if(l==r){
		tr[cnt].cnt+=tag;
		if(tr[cnt].cnt)
		  tr[cnt].mi=tr[cnt].ma=x,tr[cnt].val=deep[x];
		else tr[cnt].mi=tr[cnt].ma=tr[cnt].val=0;
		return;
	}
	int mid=(l+r)>>1;
	if(dfn[x]<=mid)upd(ls,l,mid,x,tag);
	else upd(rs,mid+1,r,x,tag);
	pushup(cnt);
}
int merge(int u,int v,int l,int r){
	if(!u||!v)return u^v;
	if(l==r){
		tr[u].cnt+=tr[v].cnt;
		tr[u].ma=_max(tr[u].ma,tr[v].ma);
		tr[u].mi=_min(tr[u].mi,tr[v].mi);
		tr[u].val=tr[u].cnt?deep[_tag[l]]:0;
		return u;
	}
	int mid=(l+r)>>1;
	tr[u].l=merge(tr[u].l,tr[v].l,l,mid);
	tr[u].r=merge(tr[u].r,tr[v].r,mid+1,r);
    pushup(u);
    return u; 
}
void work(int u){
	for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]){
	  int v=e[i].to;
	  work(v);T[u]=merge(T[u],T[v],1,n);
	}
	for(vector<int>::iterator it=vec[u].begin();it!=vec[u].end();++it){
		int v=*it;
		upd(T[u],1,n,v,-1);
	}
//	cout<<tr[T[u]].val<<" "; 
	int x=tr[T[u]].val-deep[getlca(tr[T[u]].mi,tr[T[u]].ma)];
	ans+=x;
}
int main(){
	n=rd();m=rd();
	int u,v;
	for(int i=1;i<n;++i){
		u=rd();v=rd();
		add(u,v);add(v,u);
	}
	dfs(1);
	for(int i=2;i<=now;++i)lo[i]=lo[i>>1]+1;
	for(int i=1;(1<<i)<=now;++i)
	  for(int j=1;j+(1<<i)-1<=now;++j)p[i][j]=maxx(p[i-1][j],p[i-1][j+(1<<i-1)]);
	for(int i=1;i<=m;++i){
		u=rd();v=rd();
		int lca=getlca(u,v);
		vec[lca].push_back(u);
		vec[lca].push_back(v);
		vec[fa[lca]].push_back(u);
		vec[fa[lca]].push_back(v);
		upd(T[u],1,n,u,1);upd(T[u],1,n,v,1);
		upd(T[v],1,n,u,1);upd(T[v],1,n,v,1);
	}
    work(1);
    printf("%lld
",ans/2);
	return 0;
}
原文地址:https://www.cnblogs.com/ZH-comld/p/10804737.html