题解「ZJOI2019 语言」

题意简述:求对于树上每个点 (x) ,包含它的链的并集的大小之和,也可描述成,求对于树上每个点 (x) ,它能够到达的点的个数之和。

不难发现,对于点 (x) 而言,通过树上的路径,它能够到达的点一定构成一棵树。并且这棵树上一定含有包含 (x) 点的 (s_i,t_i) 。那么也就是说,链并大小就是包含一些关键点 (s_i,t_i) 的极小连通子树 (T) 的边数。

问题转化到这里,有一个非常经典的结论,包含一些关键点 (a_1,a_2,a_3...a_n) 的极小连通子树的边数为 (|T_e|=sum_i dep_i-sum_{i=2}^n dep_{ ext{lca}(a_i,a_{i-1})}-dep_{ ext{lca}(a_1,a_2...a_n)}),其中 (a_1,a_2,a_3..a_n)( exttt{dfs}) 序从小到大排列。

那么对于每一个点 (x),找出包含 (x) 的所有路径,并且根据上式求出极小连通子树的边数即可,但是这样做,处理 (x) 点被哪些路径覆盖就是 (O(mn log^2 n)) 的,难以承受。

考虑如何快速统计所有对 (x) 有影响的路径的贡献,可以想到树上差分,对于 (s_i,t_i) 进行 (+1),对于 ( ext{lca}(s_i,t_i)) 进行 (-1), 对于 (fa_{ ext{lca}(s_i,t_i)}) 进行 (-1)。对于每一个点 (x),我们用桶来存储 (m) 条路径对第 (x) 个点的覆盖情况。这么统计点 (x) 被哪些路径覆盖就是 (O(nm)) 的。

观察一下,我们发现 ( ext{dfs}) 时将儿子的桶与父亲的桶合并,很多位置是空的,没必要统计。并且每次都要重新暴力计算一遍最小连通子树 (T) 的边数,显然不是最优的。

不妨把桶换成线段树。点 (x) 的线段树中,区间 ([l,r]) 表示被选中的关键点 ( ext{dfs})(in[l,r]) 时,极小连通子树 (T) 的边数。再维护两个量 (mx,mn) 表示当前区间 ([l,r]) 内被选中的关键点 ( exttt{dfs}) 序的最大值与最小值,(sum) 表示当前区间中被选中的关键点构成的极小连通子树 (T) 的边数,在叶子节点上存储一个 (cnt) 统计 ([l,l]) 的贡献,相当于之前的桶。稍微维护一下,总时间复杂度为 (O(mn log n))

这样的时间复杂度仍然可以优化。想一想,根据差分,父亲节点的线段树,一定与儿子节点线段树的信息是重合的,和之前的桶向上合并一样,我们将儿子节点和父亲节点的线段树合并,使用均摊时间复杂度为 (O(n log n)) 的线段树合并即可。

若使用倍增/树剖 ( ext{LCA}),总时间复杂度为 (O(n log^2 n+m log n));使用 (O(nlog n)-O(1))(LCA) ,总时间复杂度为 (O(n log n+m log n))

Show the Code

#include<cstdio>
typedef long long ll;
/*------------------------Normal I/O&handmade STL--------------------------*/ 
inline int read() {
	register int x=0,f=1;register char s=getchar();
	while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
	while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
	return x*f; 
} 
inline void swap(int &x,int &y) {int tmp=y;y=x;x=tmp;} 
/*------------------------Tree--------------------------*/ 
int cnt=0,num=0,tot=0;
int dep[100005],dfn[100005],rev[100005];
int h[100005],to[200005],ver[200005],f[100005][25];
inline void AddEdge(int x,int y) {to[++cnt]=y;ver[cnt]=h[x];h[x]=cnt;}
inline void prework(int x) {
	int fa=f[x][0];dfn[x]=++num;rev[num]=x;
	for(register int i=1;i<=20;++i) f[x][i]=f[f[x][i-1]][i-1];
	for(register int i=h[x];i;i=ver[i]) {
		int y=to[i];if(y==fa) continue;
		dep[y]=dep[x]+1;f[y][0]=x;prework(y); 
	}
}
inline int LCA(int x,int y) {
	if(!x||!y) return 0; 
	if(dep[x]>dep[y]) swap(x,y);//dep[x]<=dep[y]
	for(register int i=20;i>=0;--i) if(dep[x]<=dep[f[y][i]]) y=f[y][i];
	if(x==y) return x;
	for(register int i=20;i>=0;--i) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}
/*------------------------SegmentTree--------------------------*/
struct Segment {int mn,mx,cnt;ll sum;}t[4000005];
int lson[4000005],rson[4000005],rt[100005];
inline void pushup(int p) {
	t[p].mn=t[lson[p]].mn? t[lson[p]].mn:t[rson[p]].mn;
	t[p].mx=t[rson[p]].mx? t[rson[p]].mx:t[lson[p]].mx;
	t[p].sum=t[lson[p]].sum+t[rson[p]].sum-dep[LCA(rev[t[lson[p]].mx],rev[t[rson[p]].mn])];//???
}
inline void modify_add(int &p,int l,int r,int dfnId,int val) {
	if(!p) p=++tot;
	if(l==r) {
		t[p].cnt+=val;
		t[p].mx=t[p].mn=(t[p].cnt>0? dfnId:0);
		t[p].sum=(t[p].cnt>0? dep[rev[dfnId]]:0);
		return;
	}
	int mid=l+r>>1;
	if(dfnId<=mid) modify_add(lson[p],l,mid,dfnId,val);
	else modify_add(rson[p],mid+1,r,dfnId,val);
	pushup(p);
}
inline int merge(int x,int y,int l,int r) {
	if(!x||!y) return x|y;
	if(l==r) {t[x].cnt+=t[y].cnt;t[x].mx=t[x].mn=(t[x].cnt>0? l:0);t[x].sum=(t[x].cnt>0? dep[rev[l]]:0);return x;}
	int mid=l+r>>1;
	lson[x]=merge(lson[x],lson[y],l,mid);
	rson[x]=merge(rson[x],rson[y],mid+1,r);
	pushup(x); return x;
}
/*------------------------Solution--------------------------*/
ll ans=0;
inline void PathAdd(int x,int y,int dfnId) {
	int z=LCA(x,y),fa=f[z][0];
	modify_add(rt[x],1,num,dfn[dfnId],1);
	modify_add(rt[y],1,num,dfn[dfnId],1);
	modify_add(rt[z],1,num,dfn[dfnId],-1);
	if(fa) modify_add(rt[fa],1,num,dfn[dfnId],-1);
}
inline void solve(int x) {
	int fa=f[x][0];
	for(register int i=h[x];i;i=ver[i]) {int y=to[i];if(y==fa) continue;solve(y);}
	ans+=t[rt[x]].sum-dep[LCA(rev[t[rt[x]].mn],rev[t[rt[x]].mx])]; 
	if(fa) rt[fa]=merge(rt[fa],rt[x],1,num);
}
int main() {
	int n=read(),m=read();
	for(register int i=1;i<n;++i) {int x=read(),y=read();AddEdge(x,y);AddEdge(y,x);} dep[1]=1;prework(1);
	for(register int i=1;i<=m;++i) {int s=read(),t=read();PathAdd(s,t,s);PathAdd(s,t,t);}
	solve(1); printf("%lld
",ans>>1);
	return 0;
}
原文地址:https://www.cnblogs.com/tommy0103/p/13832510.html