51nod1819 黑白树V2

简单的题面

给定一棵以1为根的有根树,点可能是黑色或白色,操作如下。


1. 选定一个点x,将x的子树中所有到x的距离为奇数的点的颜色反转。
2. 选定一个点x,将点x的颜色反转。
3. 选定一个点x,询问所有黑点y(包括点x)与点x的lca(最近公共祖先)的和。

果然自己码一码收获挺大的....

首先考虑怎么回答3操作,不妨考虑枚举$lca$

如果$lca = x$,可以发现,在$x$子树内的答案都是$x$

否则,根节点到$x$形成了一条链,令其为$1 \to x_1 \to x_2 \to \cdots \to x$

可以发现,$x_i$对答案的贡献为$(sz[x_i] - sz[x_{i+1}]) \times x_i$,其中$sz[v]$表示$v$子树中黑点的个数(与$x$的$lca$恰为$x_i$的黑点,正是$x_i$子树中不在$x_{i+1}$子树里的那些黑点)

由于答案分布在一条链上,考虑使用轻重链剖分

考虑到2操作和1操作

用线段树动态的维护

$sz[i][2][2], col[i]$

分别表示

1.$i$节点子树中,深度为奇 / 偶数,颜色为 黑 / 白的节点数

2.$i$节点的颜色

以及

$sum[2][2]$表示区间内所有$g[i][2][2]$的和

其中$g[i][x][y] = (sz[i][x][y] - sz[son_i][x][y]) \times i$,$son_i$为$i$的重儿子,即只统计$i$本身及其轻子树部分(重链上的祖先对答案的贡献恰好是这一部分)

还有一大堆的细节,直接看代码吧,无法用言语来描述

51nod rank1,hhh

还是有$8kb$......应该还能短点的

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

// Buffered single-character reader over stdin.
// [S, T) is the unread part of the static buffer RR; when it is empty the
// buffer is refilled with fread(). T is set from fread's return value, so a
// short (final) read never re-serves stale bytes from the previous chunk.
// Returns '\0' once stdin has no more data.
inline char gc() {
    static char RR[23456];
    static char *S = RR, *T = RR;
    if(S == T) {
        T = RR + fread(RR, 1, sizeof RR, stdin);
        S = RR;
        if(S == T) return '\0';   // end of input
    }
    return *S ++;
}
// Parses the next decimal integer from the gc() character stream.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. A '\0' from gc() is treated as end of
// input: if it appears before any digit, 0 is returned instead of looping.
inline int read() {
    int p = 0, w = 1; char c = gc();
    while(c > '9' || c < '0') {
        if(c == '\0') return 0;
        if(c == '-') w = -1;
        c = gc();
    }
    while(c >= '0' && c <= '9') p = p * 10 + c - '0', c = gc();
    return p * w;
}

// Output buffer: answers are formatted into WR and flushed once by main().
int wr[50], rw;                 // digit scratch stack used by write()
char WR[40000005], *I = WR;     // output buffer and its write cursor
#define pc(z) *I ++ = (z)
// Appends the decimal form of x followed by a newline to the output buffer.
// Negative values get a leading '-'.
template <typename re>
inline void write(re x) {
    if(x < 0) {
        pc('-');
        // Peel digits while x is still negative, so the most negative value
        // (e.g. LLONG_MIN) is never negated, which would overflow.
        do { wr[++ rw] = -static_cast<int>(x % 10); x /= 10; } while(x);
    } else {
        do { wr[++ rw] = static_cast<int>(x % 10); x /= 10; } while(x);
    }
    while(rw) pc(wr[rw --] + '0');
    pc('\n');
}

#define fe float
#define de double
#define le long double
#define ll long long
#define ui unsigned int
// `register` was removed from C++17 (ill-formed; clang rejects it), and it
// has no effect on modern compilers, so `ri` is simply `int`.
#define ri int
#define ull unsigned long long
#define sid 200050   // size of per-vertex arrays
#define eid 400050   // size of per-arc arrays (each undirected edge is stored twice)

// Tree and heavy-light decomposition state (vertices are 1-indexed).
int n, m;          // number of vertices / number of operations
int cnp, did;      // arcs added so far / last DFS timestamp handed out
int dfn[sid];      // dfn[k]: vertex whose DFS timestamp is k
int ord[sid];      // ord[v]: DFS timestamp of vertex v
int anc[sid];      // anc[v]: top vertex of v's heavy chain
int col[sid];      // colour of v (1 is counted as black); kept lazily, see qc()
int cap[sid];      // cap[v]: first arc out of v (adjacency-list head)
int node[eid];     // node[e]: target of arc e
int nxt[eid];      // nxt[e]: next arc with the same source
int sz[sid];       // subtree size
int dep[sid];      // depth (root has depth 0)
int pre[sid];      // heavy child (0 if none)
int fa[sid];       // parent (0 for the root)

// Pushes the directed arc u -> v onto the front of u's adjacency list.
inline void adeg(int u, int v) {
    ++ cnp;
    node[cnp] = v;
    nxt[cnp] = cap[u];
    cap[u] = cnp;
}

#define cur node[i]
// First DFS: for every vertex in o's subtree, records parent, depth, subtree
// size and heavy child (the child with the largest subtree). fa[o] must
// already be set (it is 0 for the root via zero-initialisation).
inline void dfs(int o) {
    sz[o] = 1;
    for(int e = cap[o]; e != 0; e = nxt[e]) {
        const int son = node[e];
        if(son == fa[o]) continue;
        fa[son] = o;
        dep[son] = dep[o] + 1;
        dfs(son);
        sz[o] += sz[son];
        if(sz[son] > sz[pre[o]]) pre[o] = son;
    }
}

inline void dfs(int o, int tp) {
    anc[o] = tp; dfn[++ did] = o; ord[o] = did;
    if(pre[o]) dfs(pre[o], tp); else return;
    for(int i = cap[o]; i; i = nxt[i])
    if(cur != fa[o] && cur != pre[o]) dfs(cur, cur);
}

int f[sid][2][2], g[sid][2][2];
inline void dp(int o) {
    f[o][dep[o] & 1][col[o]] = 1;
    for(int i = cap[o]; i; i = nxt[i])
    if(cur != fa[o]) {
        dp(cur);
        for(ri d = 0; d <= 1; d ++)
        for(ri c = 0; c <= 1; c ++)
        f[o][d][c] += f[cur][d][c];
    }
    for(ri d = 0; d <= 1; d ++)
    for(ri c = 0; c <= 1; c ++)
    g[o][d][c] = f[o][d][c] - f[pre[o]][d][c];
}

// Segment-tree node over dfn positions.
//   s[p][c]   : sum over the range of g[v][p][c] * v
//   rev[p]    : pending "flip colours of depth parity p" for the children
//   mas[p][c] : pending "move mas[p][c] vertices of parity p from colour c to
//               colour c^1" in the leaves' f[][] (s is not affected)
struct Seg {
    ll s[2][2];
    int rev[2];
    int mas[2][2];
};
Seg t[sid * 4];

#define ls (o << 1)
#define rs (o << 1 | 1)

// Recomputes node o's range sums from its two children.
inline void update(int o) {
    const int lc = o << 1, rc = o << 1 | 1;
    for(int p = 0; p < 2; ++ p)
        for(int c = 0; c < 2; ++ c)
            t[o].s[p][c] = t[lc].s[p][c] + t[rc].s[p][c];
}

// Builds node o over dfn positions [l, r]; leaf k stores g[dfn[k]] * dfn[k].
inline void build(int o, int l, int r) {
    if(l != r) {
        const int mid = (l + r) >> 1;
        build(o << 1, l, mid);
        build(o << 1 | 1, mid + 1, r);
        update(o);
        return;
    }
    const int v = dfn[l];
    for(int p = 0; p < 2; ++ p)
        for(int c = 0; c < 2; ++ c)
            t[o].s[p][c] = 1ll * g[v][p][c] * v;
}

// Applies "flip every vertex of depth parity d" to node o covering [l, r].
// At a leaf the vertex's own colour (when its parity is d) and its subtree
// counts f[x][d] are flipped right away. Pending count moves of parity d swap
// direction, because they now refer to the flipped colours.
inline void prev(int o, int d, int l, int r) {
    if(l == r) {
        const int x = dfn[l];
        if(d == (dep[x] & 1)) col[x] ^= 1;
        swap(f[x][d][0], f[x][d][1]);
    }
    auto &cell = t[o];
    swap(cell.s[d][0], cell.s[d][1]);
    swap(cell.mas[d][0], cell.mas[d][1]);
    cell.rev[d] ^= 1;
}

// Applies "move v vertices of parity d from colour c to colour c^1" to the
// f[][] counts in [l, r]. Only leaves hold f, so an internal node just keeps
// the tag. The sums s are left untouched.
inline void premas(int o, int d, int c, int v, int l, int r) {
    t[o].mas[d][c] += v;
    if(l != r) return;
    int *cnt = f[dfn[l]][d];
    cnt[c] -= v;
    cnt[c ^ 1] += v;
}

// Pushes node o's pending tags to both children. Colour flips go first, then
// the count moves; this is the order in which the tags compose.
inline void pushdown(int o, int l, int r) {
    const int mid = (l + r) >> 1;
    const int lc = o << 1, rc = o << 1 | 1;
    for(int p = 0; p < 2; ++ p) {
        if(!t[o].rev[p]) continue;
        t[o].rev[p] = 0;
        prev(lc, p, l, mid);
        prev(rc, p, mid + 1, r);
    }
    for(int p = 0; p < 2; ++ p)
        for(int c = 0; c < 2; ++ c) {
            const int v = t[o].mas[p][c];
            if(v == 0) continue;
            t[o].mas[p][c] = 0;
            premas(lc, p, c, v, l, mid);
            premas(rc, p, c, v, mid + 1, r);
        }
}

// Flips colours of depth parity d on dfn positions [ml, mr]; node o covers
// [l, r]. An empty range (ml > mr) is a no-op.
inline void rev(int o, int l, int r, int ml, int mr, int d) {
    if(ml > mr || r < ml || mr < l) return;
    if(ml <= l && r <= mr) {
        prev(o, d, l, r);
        return;
    }
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    rev(o << 1, l, mid, ml, mr, d);
    rev(o << 1 | 1, mid + 1, r, ml, mr, d);
    update(o);
}

// Moves v vertices of parity d from colour c to c^1 in the f[][] of every
// vertex on dfn positions [ml, mr]; node o covers [l, r]. An empty range is
// a no-op.
inline void mas(int o, int l, int r, int ml, int mr, int d, int c, int v) {
    if(ml > mr || r < ml || mr < l) return;
    if(ml <= l && r <= mr) {
        premas(o, d, c, v, l, r);
        return;
    }
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    mas(o << 1, l, mid, ml, mr, d, c, v);
    mas(o << 1 | 1, mid + 1, r, ml, mr, d, c, v);
    update(o);
}

// Point update at vertex p: moves v vertices of parity d from colour c to
// c^1, both in f[p] and in p's leaf sums (g[p] changes by v, weighted by p).
inline void mis(int o, int l, int r, int p, int d, int c, int v) {
    if(l == r) {
        f[p][d][c ^ 1] += v;
        f[p][d][c] -= v;
        const long long delta = 1ll * v * p;
        t[o].s[d][c ^ 1] += delta;
        t[o].s[d][c] -= delta;
        return;
    }
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    if(ord[p] <= mid) mis(o << 1, l, mid, p, d, c, v);
    else mis(o << 1 | 1, mid + 1, r, p, d, c, v);
    update(o);
}

// Returns vertex v's colour after pushing every pending tag on the
// root-to-leaf path down to v's leaf.
inline int qc(int o, int l, int r, int v) {
    const int pos = ord[v];
    while(l != r) {
        const int mid = (l + r) >> 1;
        pushdown(o, l, r);
        o <<= 1;
        if(pos <= mid) {
            r = mid;
        } else {
            o |= 1;
            l = mid + 1;
        }
    }
    return col[v];
}

// Returns (colour-1 count) - (colour-0 count) among vertices of depth parity
// d in v's subtree, after pushing pending tags down to v's leaf.
inline int dsz(int o, int l, int r, int v, int d) {
    const int pos = ord[v];
    while(l != r) {
        const int mid = (l + r) >> 1;
        pushdown(o, l, r);
        o <<= 1;
        if(pos <= mid) {
            r = mid;
        } else {
            o |= 1;
            l = mid + 1;
        }
    }
    return f[v][d][1] - f[v][d][0];
}

// Returns the number of colour-1 (black) vertices in v's subtree, after
// pushing pending tags down to v's leaf.
inline int qsz(int o, int l, int r, int v) {
    const int pos = ord[v];
    while(l != r) {
        const int mid = (l + r) >> 1;
        pushdown(o, l, r);
        o <<= 1;
        if(pos <= mid) {
            r = mid;
        } else {
            o |= 1;
            l = mid + 1;
        }
    }
    return f[v][0][1] + f[v][1][1];
}

// Sum of g[v][*][1] * v (colour-1 light-part counts weighted by vertex id)
// over dfn positions [ml, mr]; 0 for an empty range.
inline ll qs(int o, int l, int r, int ml, int mr) {
    if(ml > mr || r < ml || mr < l) return 0;
    if(ml <= l && r <= mr) return t[o].s[0][1] + t[o].s[1][1];
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    const ll lhs = qs(o << 1, l, mid, ml, mr);
    return lhs + qs(o << 1 | 1, mid + 1, r, ml, mr);
}

// Operation 1: flips every vertex in x's subtree at odd distance from x,
// i.e. the vertices whose depth parity is (dep[x] + 1) & 1.
inline void change(int x) {
    const int d = (dep[x] + 1) & 1;
    int top = anc[x];
    // Colour-1 minus colour-0 count at parity d in x's subtree. After the
    // flip, every ancestor's colour-1 count of parity d changes by -der.
    const int der = dsz(1, 1, n, x, d);
    rev(1, 1, n, ord[x], ord[x] + sz[x] - 1, d);
    // Ancestors on x's own chain: x is in their heavy subtree, so only f moves.
    mas(1, 1, n, ord[top], ord[x] - 1, d, 1, der);
    // Climb chain by chain. The light-edge parent j gets a point update (its
    // light part changed), and j's chain ancestors again only need f.
    for(int j = fa[top]; j; j = fa[top]) {
        top = anc[j];
        mis(1, 1, n, j, d, 1, der);
        mas(1, 1, n, ord[top], ord[j] - 1, d, 1, der);
    }
}

inline void put(int x) {
    int ff = anc[x];
    col[x] = qc(1, 1, n, x);
    int d = dep[x] & 1, c = col[x];
    mis(1, 1, n, x, d, c, 1);
    mas(1, 1, n, ord[ff], ord[x] - 1, d, c, 1);
    for(ri i = anc[fa[ff]], j = fa[ff]; j; j = fa[i], i = anc[j])
    mis(1, 1, n, j, d, c, 1), mas(1, 1, n, ord[i], ord[j] - 1, d, c, 1);
    col[x] ^= 1;
}

// Operation 3: sum over colour-1 (black) vertices y of lca(x, y).
// x itself contributes x times the black vertices in its subtree. Each proper
// ancestor a contributes a times the black vertices in a's subtree that are
// not in the subtree of a's child on the path to x.
inline ll query(int x) {
    ll ans = 1ll * qsz(1, 1, n, x) * x;
    int top = anc[x];
    // Chain ancestors of x: x is in their heavy subtree, so each one's
    // contribution is exactly its stored light-part sum.
    ans += qs(1, 1, n, ord[top], ord[x] - 1);
    for(int j = fa[top]; j; j = fa[top]) {
        // j reaches x through its light child `top`.
        ans += 1ll * (qsz(1, 1, n, j) - qsz(1, 1, n, top)) * j;
        top = anc[j];
        ans += qs(1, 1, n, ord[top], ord[j] - 1);
    }
    return ans;
}

// Reads the tree and initial colours, builds the structures, answers the m
// operations, and flushes the buffered output once at the end.
int main() {
    n = read();
    m = read();
    for(int v = 1; v <= n; ++ v) col[v] = read();
    for(int k = 1; k < n; ++ k) {
        const int u = read();
        const int v = read();
        adeg(u, v);
        adeg(v, u);
    }
    dfs(1);          // parents, depths, sizes, heavy children
    dfs(1, 1);       // heavy-first timestamps and chain tops
    dp(1);           // initial per-subtree colour counts
    build(1, 1, n);
    for(int q = 0; q < m; ++ q) {
        const int opt = read();
        const int x = read();
        switch(opt) {
            case 1: change(x); break;
            case 2: put(x); break;
            case 3: write(query(x)); break;
            default: break;
        }
    }
    fwrite(WR, 1, I - WR, stdout);
    return 0;
}
原文地址:https://www.cnblogs.com/reverymoon/p/9486271.html