NOI.AC#2266Bacteria【根号分治,倍增】

正题

题目链接:http://noi.ac/problem/2266


题目大意

给出\(n\)个点的一棵树,有一些边上有中转站(边长度为\(2\),中间有一个中转站),否则就是边长为\(1\)

\(m\)次询问一个东西从\(x\)出发走到\(y\),每隔\(k\)步中转站会关闭一次(\(k\)的倍数步走完后不能在中转站上)。求在关闭多少次以内可以到达

\(1\leq n,m\leq 10^5\)


解题思路

发现最多只需要走\(2n\)步,然后每隔\(k\)步关闭一次,所以可以考虑根号分治。

先处理好总的倍增数组,后面求\(LCA\)和跳链要用。

对于\(k=1\)的询问,就看一下中间有没有中转站,如果有就是\(-1\)否则就是距离

对于\(k\leq \sqrt n\)的询问,我们对于每个\(k\)都进行一次预处理,处理每个周期每个点往上走能走到哪里。然后再处理一个倍增数组,然后询问的时候就在上面跳就好了

对于\(k>\sqrt n\)的询问直接每次暴力跳\(k\)步如果是中转站就跳\(k-1\)步,然后一直跳到\(LCA\)

时间复杂度\(O(n\sqrt n\log n)\),调一下块的大小就能过了


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=2e5+10,T=17;
struct edge{
	int to,next;
}a[N<<1];
struct node{
	int x,y,k,id;
}q[N];
int n,m,Q,tot,num,ans[N],ls[N],dep[N],sd[N];
int g[N][100],f[N][T+1],h[N][T+1];
void addl(int x,int y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;return;
}
bool cmp(node x,node y)
{return x.k<y.k;}
void dfs(int x,int fa){
	g[x][0]=x;sd[x]=sd[fa]+(x>n);
	f[x][0]=fa;dep[x]=dep[fa]+1;
	for(int i=1;i<=Q;i++)g[x][i]=g[fa][i-1];
	for(int i=ls[x];i;i=a[i].next){
		int y=a[i].to;
		if(y==fa)continue;
		dfs(y,x);
	}
	return;
}
int LCA(int x,int y){
	if(dep[x]>dep[y])swap(x,y);
	for(int i=T;i>=0;i--)
		if(dep[f[y][i]]>=dep[x])y=f[y][i];
	if(x==y)return x;
	for(int i=T;i>=0;i--)
		if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
	return f[x][0];
}
void calc(int x,int fa,int k){
	if(g[x][k]>n)h[x][0]=g[x][k-1];
	else h[x][0]=g[x][k];
	for(int i=ls[x];i;i=a[i].next){
		int y=a[i].to;
		if(y==fa)continue;
		calc(y,x,k);
	}
	return;
}
int query(int x,int y,int k){
	int p=LCA(x,y),ans=0;
	for(int i=T;i>=0;i--){
		if(dep[h[x][i]]>dep[p])x=h[x][i],ans+=(1<<i);
		if(dep[h[y][i]]>dep[p])y=h[y][i],ans+=(1<<i);
	}
	if(x!=y){
		int dis=dep[x]+dep[y]-2*dep[p];
		if(dis>=0&&dis<=k)ans++;
		else if(dis>k) ans+=2;
	}
	return ans;
}
int getf(int x,int k){
	for(int i=0;i<=T;i++)
		if((k>>i)&1)x=f[x][i];
	return x;
}
int solve(int x,int y,int k){
	int p=LCA(x,y),ans=0;
	while(dep[x]>dep[p]){
		int z=getf(x,k-1),t;
		if(f[z][0]>n)t=z;
		else t=f[z][0];
		if(dep[t]>dep[p])x=t,ans++;
		else break;
	}
	while(dep[y]>dep[p]){
		int z=getf(y,k-1),t;
		if(f[z][0]>n)t=z;
		else t=f[z][0];
		if(dep[t]>dep[p])y=t,ans++;
		else break;
	}
	if(x!=y){
		int dis=dep[x]+dep[y]-2*dep[p];
		if(dis>=0&&dis<=k)ans++;
		else if(dis>k) ans+=2;
	}
	return ans;
}
int main()
{
	scanf("%d",&n);num=n;
	for(int i=1;i<n;i++){
		int x,y,w;
		scanf("%d%d%d",&x,&y,&w);
		if(w==1)addl(x,y),addl(y,x);
		else{
			++num;
			addl(x,num);addl(num,y);
			addl(y,num);addl(num,x);
		}
	}
	Q=sqrt(n);
	if(Q>=70)Q=70;
	scanf("%d",&m);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&q[i].x,&q[i].y,&q[i].k);
		q[i].id=i;
	}
	sort(q+1,q+1+m,cmp);
	dfs(1,0);
	for(int j=1;j<=T;j++)
		for(int i=1;i<=num;i++)
			f[i][j]=f[f[i][j-1]][j-1];
	int l=1,r=1;
	for(;r<=m&&q[r].k<=Q;r++,l=r){
		while(r<m&&q[r].k==q[r+1].k)r++;
		if(q[r].k==1){
			for(int i=l;i<=r;i++){
				int x=q[i].x,y=q[i].y,lca=LCA(x,y);
				if(sd[x]+sd[y]-2*sd[lca])ans[q[i].id]=-1;
				else ans[q[i].id]=dep[x]+dep[y]-2*dep[lca];
			}
			continue;
		}
		calc(1,1,q[r].k);
		for(int j=1;j<=T;j++)
			for(int i=1;i<=num;i++)
				h[i][j]=h[h[i][j-1]][j-1];
		for(int i=l;i<=r;i++)
			ans[q[i].id]=query(q[i].x,q[i].y,q[i].k);
	}
	for(int i=r;i<=m;i++)
		ans[q[i].id]=solve(q[i].x,q[i].y,q[i].k);
	for(int i=1;i<=m;i++)
		printf("%d\n",ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/QuantAsk/p/14590865.html