// Heavy-light decomposition (树链剖分) — Luogu problem P3384 (洛谷 3384)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>

using namespace std;

// Reads one decimal integer from stdin, skipping any non-digit characters
// that precede it.
//
// A '-' is treated as a sign only when it immediately precedes the first
// digit. Returns 0 if end-of-file is reached before any digit is seen
// (the previous version looped forever in that case).
inline int read()
{
    int x = 0;
    int k = 1;
    int c = getchar();  // int, not char: must hold EOF and keep isdigit() well-defined

    while (c != EOF && !isdigit(c))
    {
        k = (c == '-') ? -1 : 1;  // only the character right before the digits decides the sign
        c = getchar();
    }
    while (c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }

    return k * x;
}

// One directed edge of the tree's adjacency lists. Each vertex's list is a
// singly linked chain of indices into e[]; the chain heads live in f[] and a
// next value of -1 terminates the chain. addedge() fills entries from index 1.
struct edge
{
    int u, v;  // source vertex and destination vertex
    int next;  // index of the following edge out of u, -1 at the end
} e[200020];

// Capacity of every per-vertex array below. One shared bound replaces the
// previously inconsistent 1000200 / 1002000 sizes; it is far larger than the
// vertex counts this program is expected to handle.
const int MAXN = 1002000;

int n, m, r, mod, cnt, nid = 0;  // vertex count, query count, root, modulus, edges added, DFS-order counter
int val[MAXN];  // value of each vertex as read from input (1-indexed)
int dep[MAXN];  // depth of each vertex; the root has depth 1
int siz[MAXN];  // number of vertices in each subtree
int faz[MAXN];  // parent of each vertex (0 for the root)
int top[MAXN];  // first vertex of the heavy chain containing each vertex
int son[MAXN];  // heavy child of each vertex, -1 if it has none
int tid[MAXN];  // not referenced elsewhere in this file
int lid[MAXN];  // DFS-order position of each vertex (1..n); 0 means not yet visited
int nv[MAXN];   // vertex values rearranged into DFS order: nv[lid[u]] == val[u]
int f[MAXN];    // index of the first edge in each vertex's adjacency list, -1 if empty

void addedge(int x, int y)
{
    ++cnt;
    e[cnt].u = x;
    e[cnt].v = y;
    e[cnt].next = f[x];
    f[x] = cnt; 
}

// First decomposition pass over the subtree rooted at u.
// Fills faz[] (parent), dep[] (depth) and siz[] (subtree size), and chooses
// son[u] as the child with the largest subtree; among equal sizes the first
// child encountered is kept.
void dfs1(int u, int father, int depth)
{
    faz[u] = father;
    dep[u] = depth;
    siz[u] = 1;

    for (int id = f[u]; id != -1; id = e[id].next)
    {
        const int child = e[id].v;
        if (child != father)
        {
            dfs1(child, u, depth + 1);
            siz[u] += siz[child];
            const bool heavier = son[u] == -1 || siz[child] > siz[son[u]];
            if (heavier) son[u] = child;
        }
    }
}

// Second decomposition pass. Assigns each vertex its DFS-order position lid[]
// and records its chain head top[] (h is the head of u's chain). It also copies
// val[u] into nv[] at that position. Visiting the heavy child first keeps every
// heavy chain in a contiguous range of positions. Relies on lid[] being 0 for
// vertices not yet visited.
void dfs2(int u, int h)
{
    top[u] = h;
    lid[u] = ++nid;
    nv[lid[u]] = val[u];

    const int heavy = son[u];
    if (heavy != -1)
    {
        dfs2(heavy, h);  // the heavy child continues the current chain
        for (int id = f[u]; id != -1; id = e[id].next)
        {
            const int child = e[id].v;
            if (lid[child] == 0)     // parent and heavy child are already numbered
                dfs2(child, child);  // each light child starts its own chain
        }
    }
}

// Segment tree over DFS-order positions: range add and range sum modulo mod.

// Children of segment-tree node u in the implicit array layout: 2u and 2u+1.
// Both expansions are parenthesized so the macros behave correctly inside
// larger expressions; unparenthesized, `ls + 1` would expand to `u << 2`.
#define ls (u << 1)
#define rs (u << 1 | 1)

// One segment-tree node covering DFS-order positions [l, r].
//   w   - sum of the values in [l, r] (answers are reported modulo mod)
//   siz - number of positions covered, r - l + 1
//   f   - lazy addition still owed to both children
struct tree
{
    int l, r;
    int w;
    int siz;
    int f;
} t[2000100];

// Recomputes node u's sum from its two children, modulo mod. The addition is
// done in long long: two residues plus mod can exceed INT_MAX when mod is close
// to it. The extra + mod keeps the result non-negative.
void update(int u) { t[u].w = (int)(((long long)t[ls].w + t[rs].w + mod) % mod); }

void build(int u, int l, int r)
{
    t[u].l = l;
    t[u].r = r;
    t[u].siz = r - l + 1;
    if (l == r)
    {
        t[u].w = nv[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    update(u);
}

// Pushes node u's pending lazy addition down to its two children and clears it.
// Each child's sum grows by (its size * tag). Both products and sums are formed
// in long long: siz can be ~1e5 and tag can approach mod, so int would overflow.
void pushdown(int u)
{
    if (!t[u].f) return;
    const long long tag = t[u].f;
    t[ls].w = (int)((t[ls].w + t[ls].siz * tag) % mod);
    t[rs].w = (int)((t[rs].w + t[rs].siz * tag) % mod);
    t[ls].f = (int)((t[ls].f + tag) % mod);
    t[rs].f = (int)((t[rs].f + tag) % mod);
    t[u].f = 0;
}

void add(int u, int l, int r, int x)
{
    if (l <= t[u].l && r >= t[u].r)
    {
        t[u].w += t[u].siz * x;
        t[u].f += x;
        return;
    }
    pushdown(u);
    int mid = (t[u].l + t[u].r) >> 1;
    if (l <= mid) add(ls, l, r, x);
    if (r > mid) add(rs, l, r, x);
    update(u);
}

// Returns the sum, modulo mod, of the positions in [l, r] that lie within node
// u's range. The accumulator is long long: two residues near INT_MAX would
// overflow int.
int sum(int u, int l, int r)
{
    if (l <= t[u].l && t[u].r <= r) return t[u].w;
    pushdown(u);
    long long ans = 0;
    int mid = (t[u].l + t[u].r) >> 1;
    if (l <= mid) ans = (ans + sum(ls, l, r)) % mod;
    if (r > mid) ans = (ans + sum(rs, l, r)) % mod;
    return (int)ans;
}

// Path operations on the tree, built on the heavy-light decomposition.

// Adds c to every vertex on the tree path between x and y.
// While the endpoints are on different heavy chains, the endpoint whose chain
// head is deeper is processed. Its chain segment (a contiguous DFS-order range)
// is updated, then that endpoint jumps to the parent of its chain head. Once
// both endpoints share a chain, the range between them is updated.
void tadd(int x, int y, int c)
{
    for (; top[x] != top[y]; x = faz[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        add(1, lid[top[x]], lid[x], c);
    }
    if (dep[x] <= dep[y])
        add(1, lid[x], lid[y], c);
    else
        add(1, lid[y], lid[x], c);
}

// Prints, followed by a newline, the sum modulo mod of the values on the tree
// path between x and y. It climbs chains the same way as tadd().
// Fixes: the format string was split by a raw newline (ill-formed literal), and
// the int accumulator could overflow when adding two residues near INT_MAX.
void tsum(int x, int y)
{
    long long ans = 0;
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans = (ans + sum(1, lid[top[x]], lid[x])) % mod;
        x = faz[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    ans = (ans + sum(1, lid[x], lid[y])) % mod;
    printf("%lld\n", ans);
}

int main()
{
    memset(f, -1, sizeof(f));
    memset(son, -1, sizeof(son));
    n = read();
    m = read();
    r = read();
    mod = read();
    for (int i = 1; i <= n; ++i) val[i] = read(); 
    for (int i = 1; i < n; ++i)
    {
        int x, y;
        x = read();
        y = read();
        addedge(x, y);
        addedge(y, x);
    }
    
    dfs1(r, 0, 1);
    dfs2(r, r);
    build(1, 1, n);
    
    for (int i = 1; i <= m; ++i)
    {
        int opr = read();
        int x, y, z;
        x = read();
        if (opr == 1)
        {
            y = read();
            z = read();
            z = z % mod;
            tadd(x, y, z);
        }
        else if (opr == 2)
        {
            y = read();
            tsum(x, y);
        }
        else if (opr == 3)
        {
            z = read();
            add(1, lid[x], lid[x] + siz[x] - 1, z % mod);
        }
        else if (opr == 4)
            printf("%d
", sum(1, lid[x], lid[x] + siz[x] - 1));
    }
    return 0;
}
// Original article (原文地址): https://www.cnblogs.com/yanyiming10243247/p/9700500.html