LOJ 2491 求和 (LCA + 前缀和)

Loj.2491

题意:
给一棵有根树,对Q次询问,每次输入x,y,k。输出树上x到y的路径上点的深度的k次方和。

思路:
树上两点间路径的权值和很容易想到LCA, 然后发现可以预处理深度的k次方的前缀和。对每个x和lca之间点的深度肯定是连续和,其深度k次方和(不算lca点)是sum[d[x]][k] - sum[d[lca]][k]。

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
#include<vector>
#include<string>
#include<bitset>
#include<fstream>
using namespace std;
#define rep(i, a, n) for(int i = a; i <= n; ++ i);
#define per(i, a, n) for(int i = n; i >= a; -- i);
typedef long long ll;
typedef pair<int,int> PII;
const int N = 1e6 + 105;
const int mod = 998244353;
const double Pi = acos(- 1.0);
const int INF = 0x3f3f3f3f; 
const int G = 3, Gi = 332748118;
ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; }
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
// bool cmp(int a, int b){return a > b;}
//

int n;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
ll sum[N][60];

void add(int u, int v){
    to[cnt] = v, nxt[cnt] = head[u], head[u] = cnt ++;
    to[cnt] = u, nxt[cnt] = head[v], head[v] = cnt ++;
}


//lca
int t;
int d[N], dist[N], f[N][20];

queue<int> q;

void bfs(){
    q.push(1);
    d[1] = 0;
    while(q.size()){
        int u = q.front(); q.pop();
        for(int i = head[u]; i != -1; i = nxt[i]){
            int v = to[i];
            if(d[v] != -1) continue;
            d[v] = d[u] + 1;
            f[v][0] = u;
            for(int j = 1; j <= t; ++ j)
                f[v][j] = f[f[v][j - 1]][j - 1];
            q.push(v);
        }
    }
}

int Lca(int x,int y)
{
    //调整到同样高度
    if(d[x] > d[y]) swap(x, y);
    for(int i = t; i >= 0; -- i)
        if(d[f[y][i]] >= d[x]) y = f[y][i];
    //特殊情况
    if(x == y) return x;
    //一般情况
    for(int i = t; i >= 0; -- i)
        if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}


int main()
{
    scanf("%d",&n);
    cnt = 0;
    memset(d, -1, sizeof(d));
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i <= n; ++ i){
        ll tp = 1;
        for(int j = 1; j <= 50; ++ j){
            tp = tp * 1ll * i % mod;
            sum[i][j] = (sum[i - 1][j] + tp) % mod;
        }
    }
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    t=(int)(log(n)/log(2))+1;
    bfs();
    int Q; scanf("%d",&Q);
    while(Q --){
        int x, y; ll k; scanf("%d%d%lld",&x,&y,&k);
        int lca = Lca(x, y);
        ll res = ((sum[d[x]][k] + sum[d[y]][k] - sum[d[lca]][k] * 2ll % mod + mod) % mod + qpow(d[lca], k)) % mod; 
        printf("%lld
",res);
    }
    return 0;
}

原文地址:https://www.cnblogs.com/A-sc/p/13758228.html