[51nod 1766]树上的最远点对 (树的直径+ST表求lca+线段树)

[51nod 1766]树上的最远点对 (树的直径+ST表求lca+线段树)

题面

给出一棵N个点的树,Q次询问一点编号在区间[l1,r1]内,另一点编号在区间[l2,r2]内的所有点对距离最大值。(N, Q≤100000)

分析

看到区间,我们应该想到用线段树维护,区间[l,r]存储编号在[l,r]内的点组成的一棵树的直径端点和长度

考虑如何合并区间。设两个区间的直径分别为(a,b) (c,d),则新区间的直径端点肯定也是a,b,c,d中的一个。(证明显然),那么新区间的直径就是max(dist(a,b),dist(a,c),dist(a,d),dist(b,c),dist(b,d),dist(c,d))

那么直接线段树维护就行了,pushup的时候按上面的那样合并。最后查询得到[l1,r1]内的直径(a,b),[l2,r2]内的直径(c,d) ,答案就是max(dist(a,c),dist(b,d),dist(b,c),dist(a,d))

如果用树上倍增求lca,时间复杂度为(O(nlog^2n)),改用dfs序+ST表求lca,查询只需要在ST表中求最值,是O(1)的,时间复杂度(O(nlog n))

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 100000
#define maxlogn 25
using namespace std;
inline void qread(int &x) {
	x=0;
	int sign=1;
	char c=getchar();
	while(c<'0'||c>'9') {
		if(c=='-') sign=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9') {
		x=x*10+c-'0';
		c=getchar();
	}
	x=x*sign;
}
inline void qprint(long long x) {
	if(x<0) {
		putchar('-');
		qprint(-x);
	} else if(x==0) {
		putchar('0');
		return;
	} else {
		if(x>=10) qprint(x/10);
		putchar(x%10+'0');
	}
}

int n,m;
struct edge {
	int from;
	int to;
	int next;
	int len;
} E[maxn*2+5];
int head[maxn+5];
int esz=1;
void add_edge(int u,int v,int w) {
	esz++;
	E[esz].from=u;
	E[esz].to=v;
	E[esz].len=w;
	E[esz].next=head[u];
	head[u]=esz;
}

int cnt=0;
long long dis[maxn+5];
int seq[maxn*2+5];
int first[maxn+5];
int deep[maxn*2+5];
void dfs(int x,int fa,int d) {
	seq[++cnt]=x;
	deep[cnt]=d;
	first[x]=cnt;
	for(int i=head[x]; i; i=E[i].next) {
		int y=E[i].to;
		if(y!=fa) {
			dis[y]=dis[x]+E[i].len;
			dfs(y,x,d+1);
			seq[++cnt]=x;
			deep[cnt]=d;
		}
	}
}

struct ST {//ST表求lca
	int log2[maxn*2+5];
	int f[maxn*2+5][maxlogn+5];
	void ini(int n) {
		log2[0]=-1;
		for(int i=1; i<=n; i++) log2[i]=log2[i>>1]+1;
		for(int i=1; i<=n; i++) f[i][0]=i;
		for(int j=1; (1<<j)<=n; j++) {
			for(int i=1; i+(1<<j)-1<=n; i++) {
				int x=f[i][j-1];
				int y=f[i+(1<<(j-1))][j-1];
				if(deep[x]<deep[y]) f[i][j]=x;
				else f[i][j]=y;
			}
		}
	}
	int query(int l,int r) {
		int k=log2[r-l+1];
		int x=f[l][k];
		int y=f[r-(1<<k)+1][k];
		if(deep[x]<deep[y]) return x;
		else return y;
	}
} S;
inline int lca(int x,int y) {
	x=first[x];
	y=first[y];
	if(x>y) swap(x,y);
	return seq[S.query(x,y)];
}
inline long long dist(int x,int y) {
	return dis[x]+dis[y]-2*dis[lca(x,y)];
}


struct value {//直径的端点与长度
	int u;
	int v;
	long long dis;
	value() {

	}
	value(int _u,int _v,int _dis) {
		u=_u;
		v=_v;
		dis=_dis;
	}
	inline void cal(int x,int y) {
		long long len=dist(x,y);
		if(len>dis) {
			dis=len;
			u=x;
			v=y;
		}
	}
	friend value operator + (value p,value q) {//合并答案
		if(p.u==0) return q;//注意一个为空的情况
		if(q.u==0) return p;
		int a=p.u,b=p.v,c=q.u,d=q.v;
		value ans=value(0,0,0);
		ans.cal(a,b);
		ans.cal(a,c);
		ans.cal(a,d);
		ans.cal(b,c);
		ans.cal(b,d);
		ans.cal(c,d);
		return ans;
	}
};

struct segment_tree {
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
	struct node {
		int l;
		int r;
		value val;
	} tree[maxn*4+5];
	void push_up(int x) {
		tree[x].val=tree[lson(x)].val+tree[rson(x)].val;//重载之后运算变得很简洁
	}
	void build(int l,int r,int pos) {
		tree[pos].l=l;
		tree[pos].r=r;
		tree[pos].val=value(0,0,0);
		if(l==r) {
			tree[pos].val=value(l,l,0);
			return;
		}
		int mid=(l+r)>>1;
		build(l,mid,pos<<1);
		build(mid+1,r,pos<<1|1);
		push_up(pos);
	}
	value query(int L,int R,int pos) {
		if(L<=tree[pos].l&&R>=tree[pos].r) {
			return tree[pos].val;
		}
		int mid=(tree[pos].l+tree[pos].r)>>1;
		value ans=value(0,0,0);
		if(L<=mid) ans=ans+query(L,R,pos<<1);
		if(R>mid) ans=ans+query(L,R,pos<<1|1);
		return ans;
	}
} T;


int main() {
	int u,v,w;
	int l1,r1,l2,r2;
	qread(n);
	for(int i=1; i<n; i++) {
		qread(u);
		qread(v);
		qread(w);
		add_edge(u,v,w);
		add_edge(v,u,w);
	}
	dfs(1,0,1);
	S.ini(cnt);
	T.build(1,n,1);
	
	qread(m);
	for(int i=1; i<=m; i++) {
		qread(l1);
		qread(r1);
		qread(l2);
		qread(r2);
		value ans1=T.query(l1,r1,1);
		value ans2=T.query(l2,r2,1);
		//ans=ans1+ans2 
		int a=ans1.u,b=ans1.v,c=ans2.u,d=ans2.v;//这里不能只选ans1或ans2,所以不能写ans1+ans2 
		qprint(max(max(dist(a,c),dist(b,d)),max(dist(a,d),dist(b,c))));
		putchar('
');
	}
}

原文地址:https://www.cnblogs.com/birchtree/p/11228605.html