[HNOI2015]开店

VII.[HNOI2015]开店

首先,第一种方法便是动态点分治。

我们先考虑忽略年龄限制的情形。

我们考虑正常求一个点到另一个点的距离应该怎么求——

一般来说,我们会用\(dis(i,j)=dep_i+dep_j-2*dep_{lca(i,j)}\)对吧?

这个东西相当于将路径划分成两个部分,其中每个部分的长度都易于求出。上面我们采取了\(lca(i,j)\)作为分割点。

那如果我们在点分树上求路径长度,又该如何呢?

我们或许还是能自然地想到\(dep_i+dep_j-2*dep_{lca(i,j)}\),其中\(dep_i\)\(i\)到点分树的根的距离,而\(lca\)是点分树上的最近公共祖先。

但是很明显,这是错的——点分树上的父子关系极松,这意味着不一定有\(dis(i,j)=dep_j-dep_i\),其中\(i\)\(j\)的祖先。

那怎么办呢?

这父子关系再松,有一点也是满足的——两点在点分树上的lca,一定在原树上两点间路径上

换言之,必有

\(dis(i,j)=dis\Big(i,lca(i,j)\Big)+dis\Big(j,lca(i,j)\Big)\),其中\(lca(i,j)\)为点分树上lca,而\(dis\)为原树上距离。

原树上的距离,我们可以直接用ST表在\(O(1)\)时间内求出。故这个东西可以转到点分树上求出。

我们要求

\[\sum\limits_{i=1}^{n}dis(i,x) \]

\[\sum\limits_{i=1}^{n}dis\Big(i,lca(i,x)\Big)+dis\Big(x,lca(i,x)\Big) \]

如果我们换成枚举\(lca(i,x)\),则有

\[\sum\limits_{i\text{是}x\text{的祖先}}dis(x,i)*cnt_{x,i}+sum_{x,i} \]

注意到我们这里出现了两个东西:\(cnt\)\(sum\)。其中,\(cnt\)意为子树中所有 \(lca(x,j)=i\)\(j\)的数量,而\(sum\)意为子树中所有 \(lca(x,j)=i\)\(dis(i,j)\)之和。

显然,如果\(lca(x,j)=i\),它们只要满足来自\(i\)在点分树中的不同子树即可。即:\(i\)在点分树上的子树,挖去\(x\)所在的那颗子树即可。

我们可以预处理出\(i\)的所有儿子的数量,以及它们到\(i\)的距离和。我们用一个\(vecsf\)维护。然后,在每个节点处,维护子树中所有节点数量以及它们到父亲的距离和(方便在父亲处减掉这些东西),记为\(vecfa\)

则我们只需要不断跳父亲,然后从\(vecsf\)\(vecfa\)中对应加减即可。

现在有了年龄限制,那有如何呢?

好办。之前的\(vecsf\)\(vecfa\)可以只用一个值表示即可,我们现在把它换成vector。在vector内部按照颜色排序,并作前缀和。最终只要在vector中二分即可回答询问。

但因为实现的不好,它MLE了。

代码:

#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,lim,val[100100],dep[100100],mn[200100][20],in[100100],LG[200100],tot,fa[100100],admin;
ll las;
namespace Tree{
	int sz[100100],SZ,msz[100100],ROOT,head[100100],cnt;
	struct node{
		int to,next,val;
	}edge[200100];
	void ae(int u,int v,int w){
		edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
		edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
	}
	bool vis[100100];
	void getsz(int x,int fa){
		sz[x]=1;
		for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getsz(edge[i].to,x),sz[x]+=sz[edge[i].to];
	}
	void getroot(int x,int fa){
		sz[x]=1,msz[x]=0;
		for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getroot(edge[i].to,x),sz[x]+=sz[edge[i].to],msz[x]=max(msz[x],sz[edge[i].to]);
		msz[x]=max(msz[x],SZ-sz[x]);
		if(msz[x]<msz[ROOT])ROOT=x;
	}
	void solve(int x){
		getsz(x,0); 
		vis[x]=true;
		for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])ROOT=0,SZ=sz[edge[i].to],getroot(edge[i].to,0),fa[ROOT]=x,solve(ROOT);
	}
	void getural(int x,int fa){
		mn[++tot][0]=x,in[x]=tot;
		for(int i=head[x];i!=-1;i=edge[i].next)if(edge[i].to!=fa)dep[edge[i].to]=dep[x]+edge[i].val,getural(edge[i].to,x),mn[++tot][0]=x;
	}
}
int MIN(int i,int j){
	return dep[i]<dep[j]?i:j;
}
int LCA(int i,int j){
	i=in[i],j=in[j];
	if(i>j)swap(i,j);
	int k=LG[j-i+1];
	return MIN(mn[i][k],mn[j-(1<<k)+1][k]);
}
int DIS(int i,int j){
	return dep[i]+dep[j]-dep[LCA(i,j)]*2;
}
namespace cdt{
	vector<int>v[100100];
	vector<pair<int,ll> >vecfa[100100],vecsf[100100];
	void prepvec(int x,int z){
		if(fa[z])vecfa[z].push_back(make_pair(val[x],DIS(x,fa[z])));
		vecsf[z].push_back(make_pair(val[x],DIS(x,z)));
		for(auto y:v[x])prepvec(y,z);
	}
	ll calc(int x,int L,int R){
		ll res=0;
		int u=x;
		while(x){
			int l=lower_bound(vecsf[x].begin(),vecsf[x].end(),make_pair(L,-1ll))-vecsf[x].begin()-1;
			int r=upper_bound(vecsf[x].begin(),vecsf[x].end(),make_pair(R,0x3f3f3f3f3f3f3f3fll))-vecsf[x].begin()-1;
			res+=vecsf[x][r].second-vecsf[x][l].second;
			res+=1ll*DIS(u,x)*(r-l);
			if(!fa[x])break;
			l=lower_bound(vecfa[x].begin(),vecfa[x].end(),make_pair(L,-1ll))-vecfa[x].begin()-1;
			r=upper_bound(vecfa[x].begin(),vecfa[x].end(),make_pair(R,0x3f3f3f3f3f3f3f3fll))-vecfa[x].begin()-1;
			res-=vecfa[x][r].second-vecfa[x][l].second;
			res-=1ll*DIS(u,fa[x])*(r-l);
			x=fa[x];
		}
		return res;
	}
	void prepare(){
		for(int i=1;i<=n;i++)if(fa[i])v[fa[i]].push_back(i);
		for(int i=1;i<=n;i++){
			prepvec(i,i);
			vecfa[i].push_back(make_pair(-1,0)),vecsf[i].push_back(make_pair(-1,0));
			sort(vecfa[i].begin(),vecfa[i].end());
			sort(vecsf[i].begin(),vecsf[i].end());
			for(int j=1;j<vecfa[i].size();j++)vecfa[i][j].second+=vecfa[i][j-1].second;
			for(int j=1;j<vecsf[i].size();j++)vecsf[i][j].second+=vecsf[i][j-1].second;
		}
	}
}
void read(int &x){
	x=0;
	char c=getchar();
	while(c>'9'||c<'0')c=getchar();
	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
}
int main(){
	read(n),read(m),read(lim),memset(Tree::head,-1,sizeof(Tree::head));
	for(int i=1;i<=n;i++)read(val[i]);
	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),Tree::ae(x,y,z);
	Tree::msz[0]=n+1,Tree::SZ=n,Tree::getroot(1,0),admin=Tree::ROOT,Tree::solve(Tree::ROOT);
	Tree::getural(1,0);
	for(int i=2;i<=tot;i++)LG[i]=LG[i>>1]+1;
	for(int j=1;j<=LG[tot];j++)for(int i=1;i+(1<<j)-1<=tot;i++)mn[i][j]=MIN(mn[i][j-1],mn[i+(1<<(j-1))][j-1]);
	cdt::prepare();
	for(int i=1,x,l,r;i<=m;i++){
		read(x),read(l),read(r),l=(las+l)%lim,r=(las+r)%lim;
		if(l>r)swap(l,r);
		printf("%lld\n",las=cdt::calc(x,l,r));
	}
	return 0;
} 

然后,第二种方法便是主席树+树剖。

仍然先忽略年龄限制,我们直接使用\(dis(i,j)=dep_i+dep_j-2*dep_{lca(i,j)}\),得到

\[\sum\limits_{i=1}^{n}dep_i+dep_x-2*dep_{lca(i,x)} \]

化简得

\[n*dep_x+\sum\limits_{i=1}^{n}dep_i-2*\sum\limits_{i=1}^{n}dep_{lca(i,x)} \]

前两个可以很轻松预处理出来,但是最后一部分呢?

考虑我们每个节点向上走一直走到根,在每条边上维护一个计数器,然后每次被经过了的边上的计数器就加\(1\)。然后询问的时候,从询问的点出发向上走,对于每条经过的边,答案加上(计数器大小*边权)即可。

这个很好理解,因为\(dep_{lca(i,x)}\),就等于所有既是\(i\)的祖先,又是\(x\)的祖先的边的边权和。刚才的操作的意义,就是求这些边的边权和。

如果我们用树剖的话,复杂度就是\(O(n\log^2 n)\)的。

但是,问题来了,加上边权限制怎么办?

回忆一下,在之前动态点分治的方法中,我们用前缀和实现了这一效果;而在这里,我们也可以用前缀和——只不过是对树剖时建的线段树作前缀和。线段树的前缀和,就是主席树咯。

具体而言,我们在建树的时候,将所有节点按照年龄排序后插入,这样询问时就可以直接在对应的主席树上减以下即可。

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll; 
#define mid ((l+r)>>1)
int n,m,lim,a[150010],dis[150010],fa[150010],son[150010],dfn[150010],rev[150010],sz[150010],top[150010],head[150010],tot,cnt,cpt,rt[150010],num[150010],val[150010];
ll sum[150010],ans,add[150010];
struct Edge{
	int to,next,val;
}edge[300100];
void ae(int u,int v,int w){
	edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
	edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
}
vector<int>ds,v[150010];
void dfs1(int x){
	sz[x]=1;
	for(int i=head[x],y;i!=-1;i=edge[i].next){
		if((y=edge[i].to)==fa[x])continue;
		fa[y]=x,dis[y]=dis[x]+edge[i].val,val[y]=edge[i].val;
		dfs1(y);
		sz[x]+=sz[y];
		if(sz[y]>sz[son[x]])son[x]=y;
	}
}
void dfs2(int x){
	if(son[x])dfn[++tot]=son[x],rev[son[x]]=tot,top[son[x]]=top[x],dfs2(son[x]);
	for(int i=head[x],y;i!=-1;i=edge[i].next){
		y=edge[i].to;
		if(y==fa[x]||y==son[x])continue;
		dfn[++tot]=y,rev[y]=tot,top[y]=y,dfs2(edge[i].to);
	}
}
struct SegTree{
	int lson,rson,era,tag;
	ll sum;
}seg[10001000];
void build(int &x,int l,int r){
	x=++cpt;
	if(l==r)return;
	build(seg[x].lson,l,mid),build(seg[x].rson,mid+1,r);
}
void pushup(int x,int l,int r){
	seg[x].sum=seg[seg[x].lson].sum+(sum[mid]-sum[l-1])*seg[seg[x].lson].tag+seg[seg[x].rson].sum+(sum[r]-sum[mid])*seg[seg[x].rson].tag;
}
void modify(int pre,int &x,int l,int r,int L,int R,int tim){
	if(l>R||r<L)return;
	if(seg[x].era!=tim)x=++cpt,seg[x]=seg[pre],seg[x].era=tim;
//	printf("%d:(%d,%d):(%d,%d):%d\n",x,l,r,L,R,tim);
	if(L<=l&&r<=R){seg[x].tag++;return;}
	modify(seg[pre].lson,seg[x].lson,l,mid,L,R,tim),modify(seg[pre].rson,seg[x].rson,mid+1,r,L,R,tim),pushup(x,l,r);
}
ll query(int x,int l,int r,int L,int R,int tag){
	if(!x||l>R||r<L)return 0;
	tag+=seg[x].tag;
//	printf("%d:(%d,%d):(%d,%d):%d\n",x,l,r,L,R,tag);
	if(L<=l&&r<=R)return (sum[r]-sum[l-1])*tag+seg[x].sum;
	return query(seg[x].lson,l,mid,L,R,tag)+query(seg[x].rson,mid+1,r,L,R,tag);
}
void initjump(int x){
	int col=a[x];
	while(x){
		modify(rt[col-1],rt[col],1,n,rev[top[x]],rev[x],col);
		x=fa[top[x]];
	}
}
ll queryjump(int x,int L,int R){
	ll res=0;
	while(x){
		res+=query(rt[R],1,n,rev[top[x]],rev[x],0)-query(rt[L-1],1,n,rev[top[x]],rev[x],0);
		x=fa[top[x]];
	}
//	printf("%d\n",res);
	return res;
}
int main(){
	scanf("%d%d%d",&n,&m,&lim),memset(head,-1,sizeof(head));
	for(int i=1;i<=n;i++)scanf("%d",&a[i]),ds.push_back(a[i]);
	for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),ae(x,y,z);
	dfs1(1),top[1]=rev[1]=dfn[1]=tot=1,dfs2(1);
//	for(int x=1;x<=n;x++)printf("%d::FA:%d SN:%d SZ:%d DN:%d RV:%d DS:%d TP:%d\n",x,fa[x],son[x],sz[x],dfn[x],rev[x],dis[x],top[x]);
	sort(ds.begin(),ds.end()),ds.resize(unique(ds.begin(),ds.end())-ds.begin());
	for(int i=1;i<=n;i++)a[i]=lower_bound(ds.begin(),ds.end(),a[i])-ds.begin()+1,v[a[i]].push_back(i),add[a[i]]+=dis[i],num[a[i]]++;
//	for(int i=1;i<=n;i++)printf("(%d:%d)",i,a[i]);puts("");
	for(int i=1;i<=n;i++)sum[i]=sum[i-1]+val[dfn[i]],add[i]+=add[i-1],num[i]+=num[i-1];
	build(rt[0],1,n);
	for(int i=1;i<=ds.size();i++){
		rt[i]=++cpt,seg[rt[i]]=seg[rt[i-1]],seg[rt[i]].era=i;
		for(auto x:v[i])initjump(x);
	}
	for(int i=1,x,l,r;i<=m;i++){
		scanf("%d%d%d",&x,&l,&r);
		l=(ans+l)%lim,r=(ans+r)%lim;
		if(l>r)swap(l,r);
		l=lower_bound(ds.begin(),ds.end(),l)-ds.begin()+1;
		r=upper_bound(ds.begin(),ds.end(),r)-ds.begin();
//		printf("%d %d %d\n",x,l,r);
		if(l>r){printf("%lld\n",ans=0);continue;}
//		printf("%d,%d\n",add[r]-add[l-1],num[r]-num[l-1]);
		ans=(add[r]-add[l-1])+1ll*dis[x]*(num[r]-num[l-1])-2ll*queryjump(x,l,r);
		printf("%lld\n",ans);
	}
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14605805.html