BZOJ3772精神污染

BZOJ3772精神污染

题面:权限题,去网上找题面吧。

解析

有两种思考方式:1.考虑每条路径分别被多少条路径覆盖。2.考虑每条路径分别覆盖了多少条路径。两种都简单的说一下吧。
1.可以发现覆盖路径(a,b)的路径两端必然在以a为根的子树和以b为根的子树即dfs序上连续的一段,否则也可以转化为两端。这样就可以用主席树,分别对每一条路径,在ai对应的线段树中插入bi,然后就在a这一端的线段树上,区间查询b那一端对应的范围有多少个点,再反过来做一次,就得到了答案。这是我口胡的,可能有误,因为博主用的是思路2。
2.还是像1那样做用主席树,分别对每一条路径,在ai对应的线段树中插入bi,但我们这次统计的是a到b路径上分别对应了多少个节点,我们用dfs入栈出栈序,入栈+1,出栈-1,这样对每一颗线段树,它所统计的范围就是根节点到它自己这条链上的节点个数,那我们就可以用lca,答案就是f(a)+f(b)-f(lca)-f(fa_lca),注意答案要减去1(因为统计了自己)。

代码

博主太菜了,写代码时并没有用上面的思路,复杂度要多一个log。


#include<cstdio>
#include<vector>
#include<iostream>
#define N 100005
#define mid ((l+r)>>1)
#define LL long long
using namespace std;
const int __=3e6;
inline int In(){
	char c=getchar(); int x=0,ft=1;
	for(;c<'0'||c>'9';c=getchar()) if(c=='-') ft=-1;
	for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
	return x*ft;
}
inline LL gcd(LL a,LL b){
	return (b==0)?a:gcd(b,a%b);
}
int n,m,s[N],t[N],h[N],e_tot=0;
LL ans=0,qou;
struct E{ int to,nex; }e[N<<1];
inline void add(int u,int v){
	e[++e_tot]=(E){v,h[u]}; h[u]=e_tot;
}
int d[N],fa[N],sz[N],son[N],top[N];
void dfs1(int u,int pre,int dep){
	d[u]=dep; fa[u]=pre; sz[u]=1;
	for(int i=h[u],v;i;i=e[i].nex){
		v=e[i].to; if(v==fa[u]) continue;
		dfs1(v,u,dep+1); sz[u]+=sz[v];
		if(!son[u]||sz[son[u]]<sz[v]) son[u]=v;
	}
}
int dfn[N],dfs_clock=0;
void dfs2(int u,int pre){
	top[u]=pre; dfn[u]=++dfs_clock;
	if(son[u]) dfs2(son[u],pre);
	for(int i=h[u],v;i;i=e[i].nex){
		v=e[i].to; if(v!=son[u]&&v!=fa[u]) dfs2(v,v);
	}
}
inline int LCA(int x,int y){
	while(top[x]!=top[y]){
		if(d[top[x]]<d[top[y]]) swap(x,y);
		x=fa[top[x]];
	}
	return d[x]>d[y]?y:x;
}
int rt[N],T_tot=0;
int c[__][2],sum[__];
inline int newnode(int u){
	++T_tot;
	sum[T_tot]=sum[u];
	c[T_tot][0]=c[u][0];
	c[T_tot][1]=c[u][1];
	return T_tot;
}
void Add(int G,int l,int r,int v,int& u){
	if(u==v) u=newnode(v); ++sum[u];
	if(l==r) return;
	if(G<=mid) Add(G,l,mid,c[v][0],c[u][0]);
	else Add(G,mid+1,r,c[v][1],c[u][1]);
}
vector<int> G[N];
void dfs(int u,int pre){
	rt[u]=rt[pre];
	for(int i=0;i<G[u].size();++i)
	Add(dfn[G[u][i]],1,n,rt[pre],rt[u]);
	for(int i=h[u],v;i;i=e[i].nex){
		v=e[i].to; if(v!=fa[u]) dfs(v,u);
	}
}
int Query(int L,int R,int l,int r,int w,int x,int y,int z){
	if(L<=l&&r<=R) return sum[y]+sum[z]-sum[w]-sum[x];
	if(R<=mid) return Query(L,R,l,mid,c[w][0],c[x][0],c[y][0],c[z][0]);
	if(L>mid) return Query(L,R,mid+1,r,c[w][1],c[x][1],c[y][1],c[z][1]);
	return Query(L,R,l,mid,c[w][0],c[x][0],c[y][0],c[z][0])+
	Query(L,R,mid+1,r,c[w][1],c[x][1],c[y][1],c[z][1]);
}
inline int Query(int x,int y){
	int lca=LCA(x,y),u=x,v=y,res=0;
	while(top[x]!=top[y]){
		if(d[top[x]]<d[top[y]]) swap(x,y);
		res+=Query(dfn[top[x]],dfn[x],1,n,rt[lca],rt[fa[lca]],rt[u],rt[v]);
		x=fa[top[x]];
	}
	if(d[x]<d[y]) swap(x,y);
	res+=Query(dfn[y],dfn[x],1,n,rt[lca],rt[fa[lca]],rt[u],rt[v]);
	return res;
}
int main(){
	n=In(); m=In(); qou=1ll*m*(m-1)/2;
	for(int i=1,u,v;i<n;++i){
		u=In(); v=In();
		add(u,v); add(v,u);
	}
	for(int i=1;i<=m;++i){
		s[i]=In(); t[i]=In();
		G[s[i]].push_back(t[i]);
	}
	dfs1(1,0,0); dfs2(1,1); dfs(1,0);
	for(int i=1;i<=m;++i) ans+=Query(s[i],t[i]);
	ans-=m; LL d=gcd(ans,qou); ans/=d; qou/=d;
	printf("%lld/%lld
",ans,qou);
	return 0;
}


原文地址:https://www.cnblogs.com/pkh68/p/10554972.html