[洛谷P3976] TJOI2015 旅游

问题描述

为了提高智商,ZJY 准备去往一个新世界去旅游。这个世界的城市布局像一棵树,每两座城市之间只有一条路径可以互达。

每座城市都有一种宝石,有一定的价格。ZJY 为了赚取最高利益,她会选择从 A 城市买入再转手卖到 B 城市。

由于ZJY买宝石时经常卖萌,因而凡是 ZJY 路过的城市,这座城市的宝石价格会上涨。让我们来算算 ZJY 旅游完之后能够赚取的最大利润。(如 A 城市宝石价格为 v,则ZJY出售价格也为 v)

输入格式

第一行输入一个正整数 n 表示城市个数

接下来一行输入 n 个正整数表示每座城市宝石的最初价格 p,每个宝石的初始价格不超过 100。

第三行开始连续输入 n-1 行,每行有两个数字 x 和 y。表示 x 城市和 y 城市有一条路径。城市编号从1开始。

下一行输入一个正整数 q 表示询问次数。

接下来 q 行每行输入三个正整数 a,b,v,表示 ZJY 从 a 旅游到 b,城市宝石上涨 v。

输出格式

对于每次询问,输出 ZJY 可能获得的最大利润,如果亏本了则输出 0。

样例输入

3
1 2 3
1 2
2 3
2
1 2 100
1 3 100

样例输出

1
1

数据范围

(1 le n,q le 5 imes 10^4)

链接

洛谷

解析

首先考虑如何在序列上求最大差值顺序对。用线段树维护,每个节点中记录区间最大值、最小值、顺序对的最大差值(记为答案),合并时取两个子区间的答案和右区间最大值减左区间最小值这两者的最大值。再把这个做法移动到树上。首先树链剖分,对于一条路径(u,v),我们可以把它拆分成u到LCA和LCA到v两条路径。对于u到LCA的路径,我们可以在跳重链时记录前面所有重链中的最小值,然后用当前重链中的最大值减去这个最小值以及这条重链的答案更新总答案。对于LCA到v的路径,我们从v开始跳,记录前面重链中的最大值,然后是一样的道理。注意LCA到v的路径的方向是与u到LCA的路径相反的,所以线段树维护时要同时维护顺序对和逆序对(逆序对的维护可以参考顺序对)。修改用树链剖分即可。

代码

#include <iostream>
#include <cstdio>
#define N 50002
using namespace std;
struct node{
	int maxx,minx,dat1,dat2;
};
struct SegmentTree{
	int minx,maxx,dat1,dat2,add;
}t[N*4];
int head[N],ver[N*2],nxt[N*2],l;
int n,q,i,a[N],son[N],size[N],dep[N],fa[N],top[N],in[N],pos[N],cnt;
node operator + (node a,node b){
	node ans;
	ans.maxx=max(a.maxx,b.maxx);
	ans.minx=min(a.minx,b.minx);
	ans.dat1=max(max(a.dat1,b.dat1),a.maxx-b.minx);
	ans.dat2=max(max(a.dat2,b.dat2),b.maxx-a.minx);
	return ans;
}
int read()
{
	char c=getchar();
	int w=0;
	while(c<'0'||c>'9') c=getchar();
	while(c<='9'&&c>='0'){
		w=w*10+c-'0';
		c=getchar();
	}
	return w;
}
void insert(int x,int y)
{
	l++;
	ver[l]=y;
	nxt[l]=head[x];
	head[x]=l;
}
void dfs1(int x,int pre)
{
	size[x]=1;dep[x]=dep[pre]+1;fa[x]=pre;
	for(int i=head[x];i;i=nxt[i]){
		int y=ver[i];
		if(y!=pre){
			dfs1(y,x);
			size[x]+=size[y];
			if(size[y]>size[son[x]]) son[x]=y;
		}
	}
}
void dfs2(int x,int t)
{
	top[x]=t;
	in[x]=++cnt;pos[cnt]=x;
	if(son[x]) dfs2(son[x],t);
	for(int i=head[x];i;i=nxt[i]){
		int y=ver[i];
		if(y!=fa[x]&&y!=son[x]) dfs2(y,y);
	}
}
void update(int p)
{
	t[p].maxx=max(t[p*2].maxx,t[p*2+1].maxx);
	t[p].minx=min(t[p*2].minx,t[p*2+1].minx);
	t[p].dat1=max(max(t[p*2].dat1,t[p*2+1].dat1),t[p*2].maxx-t[p*2+1].minx);
	t[p].dat2=max(max(t[p*2].dat2,t[p*2+1].dat2),t[p*2+1].maxx-t[p*2].minx);
}
void spread(int p)
{
	if(t[p].add){
		t[p*2].maxx+=t[p].add;t[p*2+1].maxx+=t[p].add;
		t[p*2].minx+=t[p].add;t[p*2+1].minx+=t[p].add;
		t[p*2].add+=t[p].add;t[p*2+1].add+=t[p].add;
		t[p].add=0;
	}
}
void build(int p,int l,int r)
{
	t[p].minx=t[p].maxx=a[pos[l]];
	if(l==r) return;
	int mid=(l+r)/2;
	build(p*2,l,mid);build(p*2+1,mid+1,r);
	update(p);
}
void change(int p,int l,int r,int ql,int qr,int x)
{
	if(ql<=l&&r<=qr){
		t[p].maxx+=x;t[p].minx+=x;
		t[p].add+=x;
		return;
	}
	spread(p);
	int mid=(l+r)/2;
	if(ql<=mid) change(p*2,l,mid,ql,qr,x);
	if(qr>mid) change(p*2+1,mid+1,r,ql,qr,x);
	update(p);
}
node ask(int p,int l,int r,int ql,int qr)
{
	if(ql<=l&&r<=qr) return (node){t[p].maxx,t[p].minx,t[p].dat1,t[p].dat2};
	int mid=(l+r)/2;
	node ans=(node){0,1<<30,0};
	spread(p);
	if(ql<=mid) ans=ask(p*2,l,mid,ql,qr)+ans;
	if(qr>mid) ans=ans+ask(p*2+1,mid+1,r,ql,qr);
	return ans;
}
int Ask(int u,int v)
{
	int maxx=0,minx=1<<30,ans=0;
	while(top[u]!=top[v]){
		if(dep[top[u]]>dep[top[v]]){
			node tmp=ask(1,1,n,in[top[u]],in[u]);
			ans=max(ans,max(tmp.dat1,tmp.maxx-minx));
			minx=min(minx,tmp.minx);
			u=fa[top[u]];
		}
		else{
			node tmp=ask(1,1,n,in[top[v]],in[v]);
			ans=max(ans,max(tmp.dat2,maxx-tmp.minx));
			maxx=max(maxx,tmp.maxx);
			v=fa[top[v]];
		}
	}
	if(dep[u]>dep[v]){
		node tmp=ask(1,1,n,in[v],in[u]);
		ans=max(ans,max(tmp.dat1,tmp.maxx-minx));
		minx=min(minx,tmp.minx);
	}
	else{
		node tmp=ask(1,1,n,in[u],in[v]);
		ans=max(ans,max(tmp.dat2,maxx-tmp.minx));
		maxx=max(maxx,tmp.maxx);
	}
	return max(ans,maxx-minx);
}
void Change(int u,int v,int w)
{
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]) swap(u,v);
		change(1,1,n,in[top[u]],in[u],w);
		u=fa[top[u]];
	}
	if(dep[u]<dep[v]) swap(u,v);
	change(1,1,n,in[v],in[u],w);
}
int main()
{
	n=read();
	for(i=1;i<=n;i++) a[i]=read();
	for(i=1;i<n;i++){
		int u=read(),v=read();
		insert(u,v);
		insert(v,u);
	}
	dfs1(1,0);dfs2(1,1);
	build(1,1,n);
	q=read();
	while(q--){
		int u=read(),v=read(),w=read();
		printf("%d
",Ask(u,v));
		Change(u,v,w);
	}
	return 0;
}

原文地址:https://www.cnblogs.com/LSlzf/p/13264099.html