题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的kk 次方和,而且每次的kk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入格式
第一行包含一个正整数(n),表示树的节点数。
之后(n-1) 行每行两个空格隔开的正整数(i, j),表示树上的一条连接点(i)和点(j)的边。
之后一行一个正整数(m)表示询问的数量。
之后每行三个空格隔开的正整数(i, j, k),表示询问从点ii 到点jj 的路径上所有节点深度的(k) 次方和。由于这个结果可能非常大,输出其对(998244353) 取模的结果。
树的节点从(1) 开始标号,其中(1)号节点为树的根。
输出格式
对于每组数据输出一行一个正整数表示取模后的结果。
输入输出样例
输入 #1复制
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
输出 #1复制
33
503245989
说明/提示
样例解释
以下用(d (i)) 表示第ii 个节点的深度。
对于样例中的树,有(d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2)
因此第一个询问答案为((2^5 + 1^5 + 0^5) mod 998244353),第二个询问答案为((2^{45} + 1^{45} + 2^{45}) mod 998244353 = 503245989)
数据范围
对于(30\%) 的数据,(1 leq n,m leq 100)
对于(60\%) 的数据,(1 leq n,m leq 1000)
对于(100\%) 的数据,(1 leq n,m leq 300000, 1 leq k leq 50)
另外存在5个不计分的hack数据
提示
数据规模较大,请注意使用较快速的输入输出方式。
敲完树剖求lca华丽走人
我们可以发现,lca的情况无非就是三种
1.(lca==a)
2.(lca==b)
3.(lca)在(a)和(b)的上面
1,2情况直接暴力跳就行,
3.情况分别从(a)向(lca)和从(b)向(lca)跳,然后我们发现(lca)算了两次,然后再减去一次(lca)的贡献就行
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int M=400100;
const int N=400100;
int ne[M],head[M],ver[M],idx;
int dep[N],fa[N],son[N],sz[N],top[N];
long long ans;
int n,m;
inline void add(int u,int v)
{
ne[idx]=head[u];
ver[idx]=v;
head[u]=idx;
idx++;
}
inline void dfs1(int u,int father,int depth)
{
fa[u]=father;
sz[u]=1;
dep[u]=depth;
for(int i=head[u]; i!=-1; i=ne[i])
{
int j=ver[i];
if(j==father)continue;
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
inline void dfs2(int u,int t)
{
top[u]=t;
if(!son[u]) return ;
dfs2(son[u],t);
for(int i=head[u]; i!=-1; i=ne[i])
{
int j=ver[i];
if(j==fa[u]||j==son[u])continue;
dfs2(j,j);
}
}
inline int lca(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
swap(u,v);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
return v;
}
inline int qmi(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=(long long)ans*a%mod;
a=(long long)a*a%mod;
b>>=1;
}
return ans;
}
inline int read()
{
int x=0;
int f=1;
char ch;
ch=getchar();
while(ch>'9'||ch<'0')
{
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10,x=x+ch-'0';
ch=getchar();
}
return x*f;
}
int main()
{
memset(head,-1,sizeof(head));
n=read();
for(register int i=1; i<n; i++)
{
int u,v;
u=read();
v=read();
add(u,v);
add(v,u);
}
dfs1(1,0,0);
dfs2(1,1);
m=read();
for(register int i=1; i<=m; i++)
{
int a,b,k;
ans=0;
a=read();
b=read();
k=read();
int LCA=lca(a,b);
if(LCA==a)
{
for(register int j=dep[a]; j<=dep[b]; j++)
{
ans=(ans+qmi(j,k)+mod)%mod;
}
}
else if(LCA==b)
{
for(register int j=dep[b]; j<=dep[a]; j++)
{
ans=(ans+qmi(j,k)+mod)%mod;
}
}
else
{
for(register int j=dep[LCA]; j<=dep[b]; j++)
{
ans=(ans+qmi(j,k)%mod+mod)%mod;
}
for(register int j=dep[LCA]; j<=dep[a]; j++)
{
ans=(ans+qmi(j,k)%mod+mod)%mod;
}
ans=(ans-qmi(dep[LCA],k)%mod+mod)%mod;
}
printf("%lld
",ans);
}
return 0;
}