[BJOI2017]树的难题

IX.[BJOI2017]树的难题

debug三天,精神崩溃

论一行if(vis[v[x][r].second]){r++;continue;}忘记加上后所有代码全都莫名其妙TLE且查不出锅的痛苦

首先,我们考虑常规淀粉质。

我们考虑一条路径,它会被(淀粉质的分治根)截成两段。如果我们对于分治树中的每一个节点,预处理出来它到树根的路径权值,记为\(sum_x\),则一条完整路径的权值则为\(sum_x+sum_y\)

稍等,我们好像忘记了一种情况——如果这两条路径顶端的边的颜色相同怎么办

换句话说,假如某一半路径的颜色段为(顺序为从根节点往下)ABABC,另一半为ABAC,两半拼一起,我们得到CBABAABAC。显然,这个A就被算了两次,应该被减掉。

因此,这种情况的权值则为\(sum_x+sum_y-val_c\),其中\(val_c\)\(c\)颜色的权值,而\(c\)为两条路径顶端的颜色。

很明显这两者要分开考虑。

然后就是求值了。显然,对于半条路径,与它可以拼成完整路径的另一半的长度,是一个区间。所以,我们可以建一棵线段树,以深度为下标,存储所有当前深度的点中,\(sum_x\)的最大值。

我们需要两棵线段树——一棵用于同种颜色的储存,而另一棵用于全局颜色的储存。在一种颜色全部处理完之后,将当前颜色线段树与全局线段树合并,并清空当前线段树即可。

明显复杂度为\(O(n\log^2n)\)。由于人傻常数大,它TLE50,即使开O3也一样。

TLE的线段树代码:

#pragma GCC optimize(3)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const ll fni=-1e18;
#define lson x<<1
#define rson x<<1|1
#define mid ((l+r)>>1)
struct SegTree{
	ll seg[800100];
	void init(){
		for(int i=1;i<=800000;i++)seg[i]=fni;
	}
	void pushup(int x){
		seg[x]=max(seg[lson],seg[rson]);
	}
	void modify(int x,int l,int r,int P,ll vl){
		if(l>P||r<P)return;
		if(l==r){seg[x]=max(seg[x],vl);return;} 
		modify(lson,l,mid,P,vl),modify(rson,mid+1,r,P,vl),pushup(x);
	}
	void setzero(int x,int l,int r){
		if(seg[x]==fni)return;
		seg[x]=fni;
		if(l!=r)setzero(lson,l,mid),setzero(rson,mid+1,r);
	}
	ll query(int x,int l,int r,int L,int R){
		if(l>R||r<L)return fni;
		if(L<=l&&r<=R)return seg[x];
		return max(query(lson,l,mid,L,R),query(rson,mid+1,r,L,R));
	}
}all,same;
void merge(int x,int l,int r){
	if(same.seg[x]==fni)return;
	all.seg[x]=max(all.seg[x],same.seg[x]);
	same.seg[x]=fni;
	if(l!=r)merge(lson,l,mid),merge(rson,mid+1,r);
}
int n,m,L,R,val[200100],ROOT,SZ,sz[200100],msz[200100];
ll mx=fni;
vector<pair<int,int> >v[200100];
bool vis[200100];
void getsz(int x,int fa){
	sz[x]=1;
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getsz(i.second,x),sz[x]+=sz[i.second];
}
void getroot(int x,int fa){
	sz[x]=1,msz[x]=0;
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getroot(i.second,x),sz[x]+=sz[i.second],msz[x]=max(msz[x],sz[i.second]);
	msz[x]=max(msz[x],SZ-sz[x]);
	if(msz[x]<msz[ROOT])ROOT=x;
}
void write(int x,int fa,SegTree &sg,int las,int dep,ll sum){
	if(dep>R)return;
	sg.modify(1,0,R,dep,sum);
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)write(i.second,x,sg,i.first,dep+1,sum+(i.first==las?0:val[i.first]));
}
void read(int x,int fa,SegTree &sg,int las,int dep,ll sum){
	if(dep>R)return;
	mx=max(mx,sum+sg.query(1,0,R,L-dep,R-dep));
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)read(i.second,x,sg,i.first,dep+1,sum+(i.first==las?0:val[i.first]));
}
void calc(int x){
	all.modify(1,0,R,0,0);
	for(int l=0,r=0;r<v[x].size();l=r){
		while(r<v[x].size()&&v[x][r].first==v[x][l].first){
			if(vis[v[x][r].second]){r++;continue;}
			int i=v[x][r].second,j=v[x][r].first;
			read(i,x,same,j,1,0);
			write(i,x,same,j,1,val[j]);
			read(i,x,all,j,1,val[j]);
			r++;
		}
		merge(1,0,R);
	}
	all.setzero(1,0,R);
}
void solve(int x){
	calc(x);
	getsz(x,0); 
	vis[x]=true;
	for(auto i:v[x])if(!vis[i.second])ROOT=0,SZ=sz[i.second],getroot(i.second,0),solve(ROOT);
}
void read(int &x){
	x=0;
	char c=getchar();
	int fl=1;
	while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
	x*=fl;
}
int main(){
	read(n),read(m),read(L),read(R);
	for(int i=1;i<=m;i++)read(val[i]);
	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),v[x].push_back(make_pair(z,y)),v[y].push_back(make_pair(z,x));
	for(int i=1;i<=n;i++)sort(v[i].begin(),v[i].end());
	all.init(),same.init();
	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
	printf("%lld\n",mx);
	return 0;
}

然后就是正解了——一个名叫单调队列按秩合并的trick。

显然,如果我们把所有半路径按照深度排序,它们合法的转移区间是单调递减的。

比如说,如果设深度为\(dep_x\)的话,则合法深度区间则为\([L-dep_x,R-dep_x]\)。当\(dep_x\nearrow\)时,整个区间\(\searrow\)

这不是经典老题滑动窗口吗?使用单调队列维护即可。

我们考虑颜色相同的情况。我们可以用来维护相同深度时的\(sum\)的最大值。对于每一棵子树,我们按照节点深度处理,在桶上跑滑动窗口。在整棵子树跑完后,用它们的值更新桶即可。

然后颜色不同的情况类似,只不过是对于每一种颜色一起处理,不需要关心具体从哪棵子树过来罢了。

稍等,这个算法是假的。很明显,这个算法的复杂度为桶的大小。该大小最大可以到直径,即\(n\)级别。如果开门见喜,一上来就遇到了直径,则之后每一次滑动窗口都要完整跑一遍直径。如果儿子数量很多的话,复杂度是会退化成\(O(n^2)\)的。

那怎么办呢?

对于颜色相同时,我们按照子树内最大深度递增的方式处理子树,先处理深度浅的子树,再处理深度深的。这就保证了单调队列的复杂度是严格\(O(\sum dep)\),即\(O(n)\)的。

在颜色不同时,我们仍然这样做,先处理深度浅的颜色,再处理深度深的。

这种trick,就是单调队列按秩合并——按照长度递增的顺序处理多条单调队列

最后还有一件事——排序。显然,如果直接排序,总复杂度是\(O(n\log^2n)\)的。当然,常数比之前小很多(因为每一层的排序复杂度都跑不满),因此可以过去。当然,因为深度范围小,使用桶排即可。

或者放弃dfs,使用bfs,毕竟bfs本身就保证了按照深度排序。

这里两种代码都能AC,\(n\log n\)的bfs或者\(n\log^2n\)的dfs。

bfs:

#pragma GCC optimize(3)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
using namespace std;
int n,m,L,R,val[200100],ROOT,SZ,sz[200100],msz[200100],dep[200100],sum[200100],mdp[200100],cdp[200100],mx=0x80808080,Glo[200100],Loc[200100],glo,loc;
vector<pair<int,int> >v[200100];
bool vis[200100];
void getsz(int x,int fa){
	sz[x]=1;
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getsz(i.second,x),sz[x]+=sz[i.second];
}
void getroot(int x,int fa){
	sz[x]=1,msz[x]=0;
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getroot(i.second,x),sz[x]+=sz[i.second],msz[x]=max(msz[x],sz[i.second]);
	msz[x]=max(msz[x],SZ-sz[x]);
	if(msz[x]<msz[ROOT])ROOT=x;
}
void getdep(int x,int fa,int las){
	mdp[x]=dep[x]=dep[fa]+1;
//	printf("%d:%d %d\n",x,dep[x],sum[x]);
	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)sum[i.second]=sum[x]+(i.first==las?0:val[i.first]),getdep(i.second,x,i.first),mdp[x]=max(mdp[x],mdp[i.second]);
}
bool cmp(pair<int,int>x,pair<int,int>y){
	if(x.first==y.first)return mdp[x.second]<mdp[y.second];
	return cdp[x.first]==cdp[y.first]?x.first<y.first:cdp[x.first]<cdp[y.first];
}
deque<int>dq;
queue<int>q;
void bfswrite(int *arr,int &lim){
	while(!q.empty()){
		int x=q.front();q.pop();
		if(dep[x]>R)break;
		arr[dep[x]]=max(arr[dep[x]],sum[x]),lim=max(lim,dep[x]);
		for(auto i:v[x])if(dep[i.second]>dep[x]&&!vis[i.second])q.push(i.second);
	}
}
void bfsread(int *arr,int lim,int delta){
	while(!q.empty()){
		int x=q.front();q.pop();
		if(dep[x]>R)break;
		while(lim>=0&&lim+dep[x]>=L){
			while(!dq.empty()&&arr[dq.back()]<=arr[lim])dq.pop_back();
			dq.push_back(lim--);
		}
		while(!dq.empty()&&dq.front()+dep[x]>R)dq.pop_front();
		if(!dq.empty())mx=max(1ll*mx,0ll+arr[dq.front()]+sum[x]-delta);
		for(auto i:v[x])if(dep[i.second]>dep[x]&&!vis[i.second])q.push(i.second);
	}
	dq.clear();
}
void calc(int x){
//	printf("ROOT:%d:\n",x);
	dep[0]=-1,sum[x]=0;
	getdep(x,0,0);
	for(auto i:v[x])if(!vis[i.second])cdp[i.first]=max(cdp[i.first],mdp[i.second]);
	sort(v[x].begin(),v[x].end(),cmp);
	Glo[0]=0;
	for(int l=0,r=0;r<v[x].size();l=r){
		while(r<v[x].size()&&v[x][r].first==v[x][l].first){
			if(!vis[v[x][r].second])q.push(v[x][r].second),bfsread(Loc,loc,val[v[x][r].first]),q.push(v[x][r].second),bfswrite(Loc,loc);
			r++;
		}
		for(int k=l;k<r;k++)if(!vis[v[x][k].second])q.push(v[x][k].second);
		bfsread(Glo,glo,0);
		for(int k=0;k<=loc;k++)Glo[k]=max(Glo[k],Loc[k]),Loc[k]=0x80808080;
		glo=max(glo,loc),loc=0;
	}
	for(int k=0;k<=glo;k++)Glo[k]=0x80808080;
	glo=0;
	for(auto i:v[x])if(!vis[i.second])cdp[i.first]=0;
}
void solve(int x){
	calc(x);
	getsz(x,0); 
	vis[x]=true;
	for(auto i:v[x])if(!vis[i.second])ROOT=0,SZ=sz[i.second],getroot(i.second,0),solve(ROOT);
}
void read(int &x){
	x=0;
	char c=getchar();
	int fl=1;
	while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
	x*=fl;
}
int main(){
	read(n),read(m),read(L),read(R),memset(Glo,0x80,sizeof(Glo)),memset(Loc,0x80,sizeof(Loc));
	for(int i=1;i<=m;i++)read(val[i]);
	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),v[x].push_back(make_pair(z,y)),v[y].push_back(make_pair(z,x));
	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
	printf("%d\n",mx);
	return 0;
}

dfs:

#pragma GCC optimize(3)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
using namespace std;
int n,m,L,R,head[200100],cnt,val[200100],mx=0x80808080;
int ROOT,SZ,sz[200100],msz[200100];
int cdp[200100],ddp[200100],buc[200100];
struct Edge{
	int to,next,val;
}edge[400100];
void ae(int u,int v,int w){
	edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
	edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
}
struct node{
	int dep,sum,col,frm;
	node(int A,int B,int C,int D){dep=A,sum=B,col=C,frm=D;}
};
bool cmp1(const node &x,const node &y){
	if(x.col!=y.col){
		if(cdp[x.col]!=cdp[y.col])return cdp[x.col]<cdp[y.col];
		return x.col<y.col;
	}
	if(x.frm!=y.frm){
		if(ddp[x.frm]!=ddp[y.frm])return ddp[x.frm]<ddp[y.frm];
		return x.frm<y.frm;
	}
	return x.dep<y.dep;
}
bool cmp2(const node &x,const node &y){
	return x.dep<y.dep;
}
vector<node>arr;
bool vis[200100];
void getsz(int x,int fa){
	sz[x]=1;
	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getsz(edge[i].to,x),sz[x]+=sz[edge[i].to];
}
void getroot(int x,int fa){
	sz[x]=1,msz[x]=0;
	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getroot(edge[i].to,x),sz[x]+=sz[edge[i].to],msz[x]=max(msz[x],sz[edge[i].to]);
	msz[x]=max(msz[x],SZ-sz[x]);
	if(msz[x]<msz[ROOT])ROOT=x;
}
void getdep(int x,int fa,int dep,int las,int sum,int col,int frm){
	if(dep>R)return;
	cdp[col]=max(cdp[col],dep);
	ddp[frm]=max(ddp[frm],dep);
	arr.push_back(node(dep,sum,col,frm));
	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getdep(edge[i].to,x,dep+1,edge[i].val,sum+(las==edge[i].val?0:val[edge[i].val]),col,frm);
}
deque<int>dq;
void calc(int x){
	arr.clear();
	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])getdep(edge[i].to,x,1,edge[i].val,val[edge[i].val],edge[i].val,edge[i].to);
	sort(arr.begin(),arr.end(),cmp1);
	for(int l=0,r=0,tmp=0,lim=0;r<arr.size();l=r){
		dq.clear(),tmp=lim;
		while(r<arr.size()&&arr[r].frm==arr[l].frm){
			while(tmp>=0&&tmp+arr[r].dep>=L){
				while(!dq.empty()&&buc[dq.back()]<=buc[tmp])dq.pop_back();
				dq.push_back(tmp--);
			}
			while(!dq.empty()&&dq.front()+arr[r].dep>R)dq.pop_front();
			if(!dq.empty())mx=max(1ll*mx,0ll+buc[dq.front()]+arr[r].sum-val[arr[r].col]);
			r++;
		}
		if(r==arr.size()||arr[r].col!=arr[l].col){
			for(int k=0;k<=lim;k++)buc[k]=0x80808080;
			lim=0;
		}else for(int k=l;k<r;k++)buc[arr[k].dep]=max(buc[arr[k].dep],arr[k].sum),lim=max(lim,arr[k].dep);
	}
	buc[0]=0;
	for(int l=0,r=0,tmp=0,lim=0;r<arr.size();l=r){
		dq.clear();tmp=lim;
		while(r<arr.size()&&arr[r].col==arr[l].col)r++;
		sort(arr.begin()+l,arr.begin()+r,cmp2);
		for(int k=l;k<r;k++){
			while(tmp>=0&&tmp+arr[k].dep>=L){
				while(!dq.empty()&&buc[dq.back()]<=buc[tmp])dq.pop_back();
				dq.push_back(tmp);
				tmp--;
			}
			while(!dq.empty()&&dq.front()+arr[k].dep>R)dq.pop_front();
			if(!dq.empty())mx=max(1ll*mx,0ll+buc[dq.front()]+arr[k].sum);
		}
		if(r==arr.size()){
			for(int k=0;k<=lim;k++)buc[k]=0x80808080;
			lim=0;
		}else for(int k=l;k<r;k++)buc[arr[k].dep]=max(buc[arr[k].dep],arr[k].sum),lim=max(lim,arr[k].dep);
	}
	buc[0]=0x80808080;
	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])cdp[edge[i].val]=ddp[edge[i].to]=0;
}
void solve(int x){
	calc(x);
	getsz(x,0); 
	vis[x]=true;
	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])ROOT=0,SZ=sz[edge[i].to],getroot(edge[i].to,0),solve(ROOT);
}
void read(int &x){
	x=0;
	char c=getchar();
	int fl=1;
	while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
	x*=fl;
}
int main(){
	read(n),read(m),read(L),read(R),memset(head,-1,sizeof(head)),memset(buc,0x80,sizeof(buc));
	for(int i=1;i<=m;i++)read(val[i]);
	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),ae(x,y,z);
	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
	printf("%d\n",mx);
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14605822.html