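// 树链剖分 (heavy-light decomposition)
//
// Path and subtree add / sum queries on a rooted tree, all values taken modulo `mod`.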

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 3e5 + 9;
template <typename T>
inline void read(T &x) {
    // Reads one signed decimal integer from stdin into x.
    // Non-digit characters before the number are skipped. If any '-' appears
    // among them, the result is negated. The first non-digit character after
    // the number is consumed.
    // If stdin runs out before any digit appears, x is set to 0. (The previous
    // version spun forever on EOF and used `register`, which C++17 removed.)
    bool negative = false;
    int c = getchar();
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-') negative = true;
    for (x = 0; c >= '0' && c <= '9'; c = getchar())
        x = x * 10 + (c - '0');
    if (negative) x = -x;
}
template <typename T>
inline void write(T x) {
    // Prints the integer x to stdout in decimal, with a leading '-' if negative.
    // The magnitude is computed in the unsigned counterpart of T. Negating the
    // signed value directly overflows (UB) for the type's minimum value.
    typedef typename std::make_unsigned<T>::type U;
    U mag = static_cast<U>(x);
    if (x < 0) {
        putchar('-');
        mag = static_cast<U>(static_cast<U>(0) - mag);  // well-defined modular negation
    }
    char buf[std::numeric_limits<U>::digits10 + 2];
    int len = 0;
    do {
        buf[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) putchar(buf[--len]);  // digits were collected least-significant first
}
// Problem parameters, all read at the start of solve():
//   n    - number of tree nodes (the segment tree spans positions 1..n)
//   m    - number of operations to process
//   root - the node the tree is rooted at (passed to dfs1/dfs2)
//   mod  - modulus applied to stored values and printed answers
ll n, m, root, mod;
struct segmentTree {
    struct node {
        int l, r, L, R;
        ll add, sum;
    }tr[N << 2];
    inline void pushup(int p) {tr[p].sum = (tr[tr[p].l].sum*1ll + tr[tr[p].r].sum*1ll)%mod;}
    inline void pushdown(int p) {
        if (tr[p].add) {
            ll d = tr[p].add;
            tr[p].add = 0;
            (tr[tr[p].l].sum += d %mod * (tr[tr[p].l].R - tr[tr[p].l].L + 1)%mod)%=mod;
            (tr[tr[p].r].sum += d %mod * (tr[tr[p].r].R - tr[tr[p].r].L + 1)%mod)%=mod;
            (tr[tr[p].l].add += d)%=mod;
            (tr[tr[p].r].add += d)%=mod;
        }
    }
    inline void build(int l, int r, int p) {
        tr[p].l = p<<1;
        tr[p].r = p<<1|1;
        tr[p].L = l, tr[p].R = r;
        if (l == r) {
            tr[p].sum = 0, tr[p].add = 0;
            return;
        }
        ll mid = l + r >> 1;
        build(l, mid, tr[p].l),build(mid + 1, r, tr[p].r);
        pushup(p);
    }
    inline void add(ll l, ll r, ll p, ll k) {
        if (tr[p].L >= l && tr[p].R <= r) {
            tr[p].add += k;
            tr[p].add %= mod;
            tr[p].sum += (tr[p].R - tr[p].L + 1)%mod * k%mod;
            tr[p].sum %= mod;
            return;
        }
        pushdown(p);
        if (tr[tr[p].l].R >= l)
        add(l, r, tr[p].l, k);
        if (tr[tr[p].r].L <= r)
        add(l, r, tr[p].r,k);
        pushup(p);
    }
    inline int ask(int l, int r, int p) {
        int ret = 0;
        pushdown(p);
        if (tr[p].L >= l && tr[p].R <= r) return tr[p].sum%mod;
        if (tr[tr[p].l].R >= l) (ret += ask(l, r, tr[p].l))%=mod;
        if (tr[tr[p].r].L <= r) (ret += ask(l, r, tr[p].r))%=mod;
        return ret;
    }
}T;
// Adjacency lists of the tree. Each edge is stored in both directions (filled in solve()).
vector<int>G[N];
// Filled by dfs1:
//   size[u] - number of nodes in u's subtree
//   fa[u]   - parent of u (0 for the root, since solve() calls dfs1(root, 0))
//   dep[u]  - depth of u (root has depth 1)
//   son[u]  - heavy child: the child with the largest subtree (0 for a leaf)
// NOTE(review): with `using namespace std;` under C++17, an unqualified `size`
// is ambiguous with std::size. Uses must be written `::size`.
int size[N], fa[N], dep[N], son[N];
// id[u]  - DFS position of u assigned by dfs2; used as u's segment-tree index
// top[u] - head node of the heavy chain containing u (set by dfs2)
// w[u]   - initial weight of u, read in solve() and loaded into the tree by dfs2
int id[N], top[N], w[N];
// Last DFS position handed out by dfs2 (positions start at 1).
int cnt;
void dfs1(int u, int fat) {
    size[u] = 1;
    fa[u] = fat;
    if (fat != -1)
    dep[u] = dep[fat] + 1;
    int t = -1;
    for (auto v:G[u]) {
        if (v == fat)continue;
        dfs1(v, u);
        size[u] += size[v];
        if (size[v] > t) {
            t = size[v];
            son[u] = v;
        }
    }
}
// Second heavy-light pass: hand out DFS positions with the heavy child first,
// so each heavy chain (and each subtree) occupies a contiguous id range.
// `fat` is the head of the chain that u belongs to. A light child starts a new
// chain headed by itself. Each non-zero initial weight w[u] is loaded into the
// segment tree at u's position.
void dfs2(int u, int fat) {
    id[u] = ++cnt;
    top[u] = fat;
    if (w[u]) T.add(id[u], id[u], 1, w[u]);

    const int heavy = son[u];
    if (!heavy) return;  // leaf: no children to visit

    dfs2(heavy, fat);  // heavy child continues the current chain
    for (int v : G[u])
        if (v != fa[u] && v != heavy)
            dfs2(v, v);
}
void addPath(int u, int v, int k) {
    k %= mod;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]])
        swap(u, v);
        T.add(id[top[u]], id[u],1, k);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    T.add(id[u], id[v], 1,k);
}
int askPath(int u, int v) {
    ll ret = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        ret += T.ask(id[top[u]], id[u], 1);
        ret %= mod;
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    ret += T.ask(id[u], id[v], 1);
    ret %= mod;return ret;
}
void addSon(int u, int k) {k %= mod;T.add(id[u], id[u] + size[u] - 1,1, k);}
int querySon(int u) { return T.ask(id[u], id[u] + size[u] - 1, 1); }
// Reads n, m, root, mod, the n initial weights and the n-1 edges. Builds the
// heavy-light decomposition, then processes m operations:
//   1 u v z : add z to every node on the path u-v
//   2 u v   : print the sum over the path u-v
//   3 x z   : add z to every node in x's subtree
//   4 x     : print the sum over x's subtree
// All values are reduced into [0, mod), so negative inputs never produce
// negative stored sums or answers.
void solve() {
    read(n), read(m), read(root), read(mod);
    T.build(1, n, 1);
    for (int i = 1; i <= n; i ++) {
        read(w[i]);
        w[i] = static_cast<int>((w[i] % mod + mod) % mod);
    }
    for (int i = 1; i < n; i ++) {
        ll u, v;
        read(u), read(v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(root, 0);
    dfs2(root, root);
    while (m--) {
        ll op;
        read(op);
        if (op == 1) {
            ll u, v, z;
            read(u), read(v), read(z);
            z = (z % mod + mod) % mod;
            addPath(u, v, z);
        } else if (op == 2) {
            ll u, v;
            read(u), read(v);
            write(askPath(u, v) % mod);
            puts("");
        } else if (op == 3) {
            int x, z;
            read(x), read(z);
            z = static_cast<int>((z % mod + mod) % mod);
            addSon(x, z);
        } else if (op == 4) {
            int x;
            read(x);
            write(querySon(x) % mod);
            puts("");
        }
    }
}
signed main() {
    ll t = 1;//cin >> t;
    while (t--)  solve();
    return 0;
}
// 原文地址 (original article): https://www.cnblogs.com/Xiao-yan/p/14956084.html