洛谷 P5327 [ZJOI2019]语言

洛谷 P5327 [ZJOI2019]语言

https://www.luogu.com.cn/problem/P5327

Snipaste_2020-07-02_19-25-05.png

Snipaste_2020-07-02_19-24-50.png

Snipaste_2020-07-02_19-25-15.png

Tutorial

https://www.luogu.com.cn/blog/Sooke/solution-p5327

考虑如果 (n,m le 5 imes 10^3) 怎么做.

对于一个点 (u) ,如果我们将所有经过它的 (s,t) 点拿出来,发现它所可以到达的区域实际就是这些点的虚树的大小.

虚树的大小可以在dfs序上用线段树维护,默认 (1) 节点在虚树中,每个区间维护区间内虚树大小,dfs序最小(mn)和最大的节点(mx),合并的时候减去左边的(mx)和右边的(mn)的lca深度,计算答案时减去根节点(mn,mx)的lca深度即可.

考虑(n,mle10^5)的时候,我们可以用线段树合并来维护每个节点的虚树,将所有路径在树上差分一下即可.

Code

#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
inline char gc() {
//	return getchar();
	static char buf[100000],*l=buf,*r=buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void rd(T &x) {
	x=0; int f=1,ch=gc();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
	while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=gc();}
	x*=f; 
}
typedef long long ll;
const int maxn=1e5+50;
const int maxdfc=maxn<<1;
int n,m; ll an;
int head[maxn];
int dfc,dfn[maxn],dep[maxn],a[20][maxdfc];;
vector<int> ad[maxn],de[maxn];
int root[maxn];
struct edge {
	int to,nex;
	edge(int to=0,int nex=0):to(to),nex(nex){}
};
vector<edge> G;
inline void addedge(int u,int v) {
	G.push_back(edge(v,head[u])),head[u]=G.size()-1;
	G.push_back(edge(u,head[v])),head[v]=G.size()-1;
}
namespace rmq {
	int bit[20],lg2[maxdfc];
	inline int cmp(int a,int b) {return dep[a]<dep[b]?a:b;}
	void dfs(int u,int fa) {
		a[0][dfn[u]=++dfc]=u;
		for(int i=head[u];~i;i=G[i].nex) {
			int v=G[i].to; if(v==fa) continue;
			dep[v]=dep[u]+1;
			dfs(v,u);
			a[0][++dfc]=u;
		}
	}
	void init() {
		dfs(1,0);
		bit[0]=1;
		for(int i=1;i<20;++i) bit[i]=bit[i-1]<<1;
		lg2[0]=-1;
		for(int i=1;i<=dfc;++i) lg2[i]=lg2[i>>1]+1;
		for(int k=1;bit[k]<=dfc;++k) {
			for(int i=1;i+bit[k]-1<=dfc;++i) {
				a[k][i]=cmp(a[k-1][i],a[k-1][i+bit[k-1]]);
			}
		}
	}
	inline int query(int l,int r) {
		int k=lg2[r-l+1];
		return cmp(a[k][l],a[k][r-bit[k]+1]);
	}
	inline int lca(int u,int v) {
		if(dfn[u]>dfn[v]) swap(u,v);
		return query(dfn[u],dfn[v]);
	}
}
namespace seg {
	const int maxnode=maxn*100;
	int ncnt;
	struct node {
		int ls,rs,cnt,mn,mx,val;
		node() {mn=mx=-1;}
		void doit(node other) {
			cnt=other.cnt,mn=other.mn,mx=other.mx,val=other.val;
		}
	} tree[maxnode];
	inline void pushup(int u) {
		int ls=tree[u].ls,rs=tree[u].rs;
		if(tree[ls].mn==-1) {tree[u].doit(tree[rs]); return;}
		if(tree[rs].mn==-1) {tree[u].doit(tree[ls]); return;}
		tree[u].mn=tree[ls].mn,tree[u].mx=tree[rs].mx;
		tree[u].val=tree[ls].val+tree[rs].val-dep[rmq::lca(tree[ls].mx,tree[rs].mn)];
	}
	void update(int &u,int l,int r,int qp,int qv) {
		if(!u) u=++ncnt;
		if(l==r) {
			tree[u].cnt+=qv;
			if(tree[u].cnt==0) tree[u]=node();
			else {
				tree[u].mn=tree[u].mx=a[0][qp];
				tree[u].val=dep[a[0][qp]];
			}
			return;
		}
		int mid=(l+r)>>1;
		if(qp<=mid) update(tree[u].ls,l,mid,qp,qv);
		else update(tree[u].rs,mid+1,r,qp,qv);
		pushup(u);
	}
	void merge(int &u,int v,int l,int r) {
		if(u==0||v==0) {u=u+v; return;}
		if(l==r) {
			tree[u].cnt+=tree[v].cnt;
			if(tree[u].cnt==0) tree[u]=node();
			else {
				tree[u].mn=tree[u].mx=a[0][l];
				tree[u].val=dep[a[0][l]];
			}
			return;
		}
		int mid=(l+r)>>1;
		merge(tree[u].ls,tree[v].ls,l,mid);
		merge(tree[u].rs,tree[v].rs,mid+1,r);
		pushup(u);
	}
	inline int sol(int u) {
		if(tree[u].mn==-1) return 0;
		return tree[u].val-dep[rmq::lca(tree[u].mn,tree[u].mx)]+1;
	}
}
void dfs(int u,int fa) {
	for(int i=0;i<ad[u].size();++i) {
		seg::update(root[u],1,dfc,dfn[ad[u][i]],1);
	}
	for(int i=head[u];~i;i=G[i].nex) {
		int v=G[i].to; if(v==fa) continue;
		dfs(v,u);
		seg::merge(root[u],root[v],1,dfc);
	}
	an+=seg::sol(root[u]);
	for(int i=0;i<de[u].size();++i) {
		seg::update(root[u],1,dfc,dfn[de[u][i]],-1);
	}
}
int main() {
	rd(n),rd(m);
	memset(head,-1,sizeof(head));
	for(int i=1;i<n;++i) {
		int u,v; rd(u),rd(v);
		addedge(u,v);
	}
	rmq::init();
	for(int i=1;i<=n;++i) {
		ad[i].push_back(i);
		de[i].push_back(i);
	}
	for(int i=1;i<=m;++i) {
		int s,t,w; rd(s),rd(t),w=rmq::lca(s,t);
		ad[s].push_back(s),ad[s].push_back(t);
		ad[t].push_back(s),ad[t].push_back(t);
		de[w].push_back(s),de[w].push_back(s);
		de[w].push_back(t),de[w].push_back(t);
	}
	dfs(1,0);
	an=(an-n)/2;
	printf("%d
",an);
	return 0;
}
原文地址:https://www.cnblogs.com/ljzalc1022/p/13226621.html