loj #2179. 「BJOI2017」树的难题 点分治+线段树

对于树上统计路径的问题我们通常要用到点分治来搞一搞。

首先我们点分治。

摄当前的分治中心是 x,那么把 x 周围的点按照颜色排个序。

统计的时候我们建两颗线段树,设当前处理到的 x 周围的点是 y,x 和 y 之间的点的颜色是 z ,那么第一颗线段树是 z 之前的颜色(不包括z),第二课线段树是 z。

每棵线段树以到 x 距离为下表,存的是到 x 这段路程的权值。

那么新统计到一个点的时候,在第一棵线段树我们直接加,第二课线段树加的时候减去拼接时候的损失。

时间复杂度 (O(nlog^2n))

细节较多。

#include<algorithm>
#include<iostream>
#include<cstdio>
#include<vector>
#define lson (k<<1)
#define rson ((k<<1)|1)
using namespace std;
int n,m,l,r,tot;
const int N=200010,inf=2e9;
int c[N];
struct bian
{
	int to,c;
	friend bool operator <(const bian &a,const bian &b){return a.c<b.c;}
};
vector<bian>v[N];
inline int read()
{
    int res = 0; char ch = getchar(); bool XX = false;
    for (; !isdigit(ch); ch = getchar())(ch == '-') && (XX = true);
    for (; isdigit(ch); ch = getchar())res = (res << 3) + (res << 1) + (ch ^ 48);
    return XX ? -res : res;
}
namespace solve1
{
	int ans=-2e9;
	void dfs(int x,int fa,int dep,int sum,int last)
	{
		if(l<=dep&&dep<=r)ans=max(ans,sum);
		if(dep>r)return;
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(v[x][i].to!=fa)dfs(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
	}
	void work()
	{
		for(int i=1;i<=n;++i)dfs(i,0,0,0,0);
		cout<<ans;
	}
}
struct XDS
{
	int tr[N<<2];
	void pushup(int k)
	{
		tr[k]=max(tr[lson],tr[rson]);
	}
	void build(int k,int l,int r)
	{
		if(l==r)
		{
			tr[k]=-inf;
			return;
		}
		int mid=(l+r)>>1;
		build(lson,l,mid);build(rson,mid+1,r);
		pushup(k);
	}
	void change(int k,int l,int r,int pos,int val)
	{
		if(l==r)
		{
			tr[k]=max(tr[k],val);
			return;
		}
		int mid=(l+r)>>1;
		if(pos<=mid)change(lson,l,mid,pos,val);
		else change(rson,mid+1,r,pos,val);
		pushup(k);
	}
	void clear(int k,int l,int r,int pos)
	{
		if(l==r)
		{
			tr[k]=-inf;
			return;
		}
		int mid=(l+r)>>1;
		if(pos<=mid)clear(lson,l,mid,pos);
		else clear(rson,mid+1,r,pos);
		pushup(k);
	}
	int ask(int k,int l,int r,int x,int y)
	{
		if(x<=l&&r<=y)return tr[k];
		int mid=(l+r)>>1,res=-inf;
		if(x<=mid)res=max(res,ask(lson,l,mid,x,y));
		if(mid+1<=y)res=max(res,ask(rson,mid+1,r,x,y));
		return res;
	}
}pre,now;
namespace solve2
{
	int root,num,ans=-inf;
	int vis[N],siz[N],mx[N];
	void Groot(int x,int fa)
	{
		siz[x]=1;mx[x]=0;
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(!vis[v[x][i].to]&&v[x][i].to!=fa)
			{
				Groot(v[x][i].to,x);
				siz[x]+=siz[v[x][i].to];mx[x]=max(mx[x],siz[v[x][i].to]);
			}
		mx[x]=max(mx[x],num-siz[x]);
		if(mx[x]<mx[root])root=x;
	}
	void dfs1(int x,int fa,int dep,int sum,int last,int se)
	{
		if(r-dep<=0)return;
		if(l<=dep&&dep<=r)ans=max(ans,sum);
		ans=max(ans,sum+pre.ask(1,1,n,max(1,l-dep),r-dep));
		ans=max(ans,sum+now.ask(1,1,n,max(1,l-dep),r-dep)-c[se]);
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs1(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c,se);
	}
	void dfs2(int x,int fa,int dep,int sum,int last)
	{
		if(r-dep<=0)return;
		now.change(1,1,n,dep,sum);
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs2(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
	}
	void dfs3(int x,int fa,int dep,int sum,int last)
	{
		if(r-dep<=0)return;
		now.clear(1,1,n,dep);pre.change(1,1,n,dep,sum);
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs3(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
	}
	void dfs4(int x,int fa,int dep,int sum,int last)
	{
		if(r-dep<=0)return;
		pre.clear(1,1,n,dep);
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs4(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
	}
	void dfs5(int x,int fa,int dep,int sum,int last)
	{
		if(r-dep<=0)return;
		now.clear(1,1,n,dep);
		for(int i=0,Siz=v[x].size();i<Siz;++i)
			if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs5(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
	}
	void solve(int x)
	{
		vis[x]=1;
		int Siz=v[x].size();
		sort(v[x].begin(),v[x].end());
		
		for(int i=0;i<Siz;++i)
			if(!vis[v[x][i].to])
			{
				if(i!=0&&v[x][i].c!=v[x][i-1].c)
				{
					for(int j=i-1;j>=0&&v[x][j].c==v[x][i-1].c;--j)
						if(!vis[v[x][j].to])dfs3(v[x][j].to,x,1,c[v[x][j].c],v[x][j].c);
				}
				dfs1(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c,v[x][i].c);
				dfs2(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c);
			}
		for(int i=0;i<Siz;++i)
			if(!vis[v[x][i].to])dfs4(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c);
		for(int i=Siz-1;i>=0&&v[x][i].c==v[x][Siz-1].c;--i)
			if(!vis[v[x][i].to])dfs5(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c);
		
		for(int i=0;i<Siz;++i)
			if(!vis[v[x][i].to])root=0,num=siz[v[x][i].to],Groot(v[x][i].to,0),solve(root);
	}
	void work()
	{
		pre.build(1,1,n);now.build(1,1,n);
		mx[0]=1<<30;root=0;num=n;Groot(1,0);solve(root);
		cout<<ans;
	}
}
int main()
{
	cin>>n>>m>>l>>r;
	for(int i=1;i<=m;++i)c[i]=read();
	for(int i=1,x,y,z;i<n;++i)
	{
		x=read(),y=read(),z=read();
		v[x].push_back((bian){y,z});
		v[y].push_back((bian){x,z});
	}
	if(n<=1000)solve1::work();
	else solve2::work();
	return 0;
}
原文地址:https://www.cnblogs.com/wljss/p/13370827.html