[BJOI2018]求和

题目概述

题目描述

(master) 对树上的求和非常感兴趣。

他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的(k) 次方和,而且每次的(k) 可能是不同的。

此处节点深度的定义是这个节点到根的路径上的边数。

他把这个问题交给了(pupil),但(pupil) 并不会这么复杂的操作,你能帮他解决吗?

输入输出格式

输入格式

第一行包含一个正整数(n),表示树的节点数。

之后(n-1) 行每行两个空格隔开的正整数(i, j),表示树上的一条连接点(i) 和点(j) 的边。

之后一行一个正整数(m),表示询问的数量。

之后每行三个空格隔开的正整数(i, j, k),表示询问从点(i) 到点(j) 的路径上所有节点深度的(k) 次方和。

由于这个结果可能非常大,输出其对(998244353) 取模的结果。

树的节点从(1) 开始标号,其中(1) 号节点为树的根。

输出格式

对于每组数据输出一行一个正整数表示取模后的结果。

输入输出样例

输入样例 #1
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
输出样例 #1
33
503245989

样例解释

以下用(d (i)) 表示第(i) 个节点的深度。 对于样例中的树,有(d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2)
因此第一个询问答案为((2^5 + 1^5 + 0^5) mod 998244353 = 33)
第二个询问答案为((2^{45} + 1^{45} + 2^{45}) mod 998244353 = 503245989)

数据范围

对于(30\%) 的数据,(1 leq n,m leq 100)
对于(60\%) 的数据,(1 leq n,m leq 1000)
对于(100\%) 的数据,(1 leq n,m leq 300000, 1 leq k leq 50)

友情提示

数据规模较大,请注意使用较快速的输入输出方式。

解题报告

题意理解

  1. 要你求出一条路经上,每一个点的深度的(k)次方之和
  2. $ 1 le k le 50$

算法解析

首先看到下面,这些话,你就可以判断本题目算法为最近公共祖先

  1. 一棵树上,大量两点路径查询
  2. 没有任何修改操作。
  3. 树上节点很多,要求复杂度不高

综上所述,我们发现这道题目满足所有的要求,因此我们可以推断出,本题目使用最近公共祖先。


我们再来分析这道题目如何使用LCA算法。

我们观察数据范围,得到(k)的值域很小。

因此我们不妨开一个长度为(50)大小的数组,存储每一个(k)对应的树。

[sum[j][i]表示从根节点到i,当前k是j的情况下的深度和 \\ 比如说sum[3][2]表示1~2这条路径上,每一个点的3次方的总和 ]

然后我们可以通过树上差分算法,解决查询问题。

[sum[k][a]+sum[k][b]-sum[k][fa[Lca]]-sum[k][Lca] ]

这里就不画图,如果要看的话,就瞅瞅我的树上差分专题吧,上面有这道题目的图,或者康康我的讲课视频,链接


代码解析

#include <bits/stdc++.h>
using namespace std;
const int N=300100,Mod=998244353;
int fa[N][21],n,m;
long long sum[51][N],deep[N],power[51];
vector<int> g[N];
void dfs(int x,int s)
{
	for(int i=1; i<=20; i++)
		fa[x][i]=fa[fa[x][i-1]][i-1];
	for(int y:g[x])
	{
		if (y==s)
			continue;
		deep[y]=deep[x]+1;
		fa[y][0]=x;
		for(int k=1; k<=50; k++)
			power[k]=power[k-1]*deep[y]%Mod;
		for(int k=1; k<=50; k++)
			sum[k][y]=(power[k]+sum[k][x])%Mod;
		dfs(y,x);
	}
}
inline int Lca(int a,int b)
{
	if (deep[a]<deep[b])
		swap(a,b);
	for(int i=20; i>=0; i--)
		if (deep[fa[a][i]]>=deep[b])
			a=fa[a][i];
	if (a==b)
		return a;
	for(int i=20; i>=0; i--)
		if (fa[a][i]!=fa[b][i])
			a=fa[a][i],b=fa[b][i];
	return fa[a][0];
}
inline void init()
{
//	freopen("data.in","r",stdin);
//	freopen("a.out","w",stdout);
	scanf("%d",&n);
	for(int i=1; i<n; i++)
	{
		int a,b;
		scanf("%d%d",&a,&b);
		g[a].push_back(b);
		g[b].push_back(a);
	}
	power[0]=1;
	dfs(1,1);
	scanf("%d",&m);
	for(int i=1; i<=m; i++)
	{
		int k,a,b;
		scanf("%d%d%d",&a,&b,&k);
		int LCA=Lca(a,b);
		long long ans=sum[k][a]+Mod+sum[k][b]+Mod-sum[k][fa[LCA][0]]-sum[k][LCA];
		ans%=Mod;
		printf("%lld
",ans);
	}
}
signed main()
{
	init();
	return 0;
}
原文地址:https://www.cnblogs.com/gzh-red/p/11831583.html