树上的等差数列 [树形dp]

树上的等差数列

题目描述

给定一棵包含 (N) 个节点的无根树,节点编号 (1 o N) 。其中每个节点都具有一个权值,第 (i) 个节点的权值是 (A_i)

(Hi) 希望你能找到树上的一条最长路径,满足沿着路径经过的节点的权值序列恰好构成等差数列。

输入格式

第一行包含一个整数 (N)

第二行包含 (N) 个整数 (A_1, A_2, ... A_N)

以下 (N-1) 行,每行包含两个整数 (U)(V) ,代表节点 (U)(V) 之间有一条边相连。

输出格式

最长等差数列路径的长度

样例

样例输入

7  
3 2 4 5 6 7 5  
1 2  
1 3  
2 7  
3 4  
3 5  
3 6

样例输出

4

数据范围与提示

对于 (50\%) 的数据,(1 leqslant N leqslant 1000)

对于 (100\%) 的数据,(1 leqslant N leqslant 100000, 0 leqslant A_i leqslant 100000, 1 leqslant U, V leqslant N)

分析

树形 (dp) 好题。

因为要求的是最长的等差序列,根节点不同,答案也可能不同,所以 (dp) 的状态转移就定义为 (f[i][j]) 表示 (i) 节点为根,公差为 (j) 时的最长的等差数列,不包括自己。那么我们就可以愉快的 (dfs) 来进行转移了。

我们记录一下他自己和他的父亲,避免出现死循环,每一次先 (dfs) 到儿子,递归上来,然后就处理出来了公差为 (Delta) 的以儿子为根的所有长度,这时候我们只需要判断一下此时的 (Delta) 值是否为 (0)。如果是,那么 (ans) 的转移应该是:

[ans = max(ans,f[x][0] + f[son[x]][0] + 2) ]

因为此时 (f[x][0]) 存储的是其他儿子上最长链,所以需要加上当前儿子的最长链,因为我们的数组不保存自己,所以要加 (2)

其他情况就是直接更新 (ans) ,他的答案应该是 (f[x][d] + f[x][-d] + 1) ,因为他的父亲那里也可能会有链,公差为 (-d) 就是那个链,由于负数下标的问题,我们利用 (map) 来存储,然后轻松解决此题。

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<map>
#define re register
using namespace std;
const int maxn = 1e5+10;
map <int,int> mp[maxn];
struct Node{
	int v,next;
}e[maxn<<1];
int w[maxn];
int ans = 0;
int head[maxn],tot;
void Add(int x,int y){//建边
	e[++tot].v = y;
	e[tot].next = head[x];
	head[x] = tot;
}
inline int read(){//快读
	int s = 0,f = 1;
	char ch = getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
	return s * f;
}
inline void DP(int x,int fa){
	for(int i=head[x];i;i=e[i].next){
		int v = e[i].v;
		if(v == fa)continue;//避免死循环
		int d = w[v] - w[x];//计算公差
		DP(v,x);
		if(!d){//公差为0的情况
			ans = max(ans,mp[x][0] + mp[v][0] + 2);
			mp[x][0] = max(mp[x][0],mp[v][0] + 1);
		}
		else{//公差不为0
			mp[x][d] = max(mp[x][d],mp[v][d] + 1);
			ans = max(ans,mp[x][d] + mp[x][-d] + 1);
		}
	}
}

int main(){
	freopen("C.in","r",stdin);
	freopen("C.out","w",stdout);
	int n =read();
	for(re int i = 1;i<=n;++i){w[i]=read();}
	for(re int i = 1;i< n;++i){
		int x = read(),y = read();
		Add(x,y);
		Add(y,x);
	}
	DP(1,0);
	printf("%d
",ans);
}
原文地址:https://www.cnblogs.com/Vocanda/p/13504501.html