Luogu 3206 [HNOI2010]城市建设

BZOJ 2001

很神仙的cdq分治

先放论文的链接   顾昱洲_浅谈一类分治算法

我们考虑分治询问,用$solve(l, r)$表示询问编号在$[l, r]$时的情况,那么当$l == r$的时候,直接把询问代入跑一个最小生成树就好了。

然而问题是怎么缩小每一层分治的规模,因为每一层都用$n$个点$m$条边来算稳$T$。

那么我们可以进行两个过程:

1、Reduction

  把与当前询问有关的边权设为$inf$跑最小生成树,那么此时不被连到最小生成树中的边一定是没什么用的,直接扔掉,这一步可以缩边。

2、Contraction

  把与当前询问有关的边权设为$-inf$跑最小生成树,那么不考虑边权为$-inf$的边连成的最小生成树的若干个连通块的边和点都是可以缩到一起的,这一步可以缩点。

这样子我们把这两个操作做完之后就可以做到把问题的规模缩减的与询问区间相关了。

时间复杂度$O(nlog^2n)$。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef pair <int, ll> pin;

const int N = 2e4 + 5;
const int M = 5e4 + 5;
const int Lg = 20;
const ll inf = 1LL << 60;

int n, m, qn, ufs[N], siz[N], pos[M], sum[Lg];
ll val[M], ans[M];
pin q[M];

struct Edge {
    int u, v, id;
    ll val;
    
    friend bool operator < (const Edge &x, const Edge &y) {
        return x.val < y.val;
    }
    
} e[Lg][M], c[M], t[M];

template <typename T>
inline void read(T &X) {
    X = 0; char ch = 0; T op = 1;
    for(; ch > '9' || ch < '0'; ch = getchar())
        if(ch == '-') op = -1;
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

inline int find(int x) {
    return ufs[x] == x ? x : ufs[x] = find(ufs[x]);
}

inline void merge(int x, int y) {
    int fx = find(x), fy = find(y);
    if(fx == fy) return;
    if(siz[fx] < siz[fy]) ufs[fx] = fy, siz[fy] += siz[fx];
    else ufs[fy] = fx, siz[fx] += siz[fy];
}

inline void clear(int tot) {
    for(int i = 1; i <= tot; i++) {
        ufs[t[i].u] = t[i].u;
        ufs[t[i].v] = t[i].v;
        siz[t[i].u] = siz[t[i].v] = 1;
    }
}

inline void cont(int &tot, ll &nowVal) {
    int cnt = 0;
    clear(tot);
    sort(t + 1, t + 1 + tot);
    for(int i = 1; i <= tot; i++) {
        int u = find(t[i].u), v = find(t[i].v);
        if(u == v) continue;
        merge(u, v);
        c[++cnt] = t[i];
    }
    
    for(int i = 1; i <= cnt; i++) {
        ufs[c[i].u] = c[i].u;
        ufs[c[i].v] = c[i].v;
        siz[c[i].u] = siz[c[i].v] = 1;
    }
    
    for(int i = 1; i <= cnt; i++) {
        if(c[i].val == -inf) continue;
        int u = find(c[i].u), v = find(c[i].v);
        if(u == v) continue;
        merge(u, v);
        nowVal += c[i].val;    
    }
    
    cnt = 0;
    for(int i = 1; i <= tot; i++) {
        int u = find(t[i].u), v = find(t[i].v);
        if(u == v) continue;
        c[++cnt] = t[i];
        pos[t[i].id] = cnt;
        c[cnt].u = ufs[t[i].u];
        c[cnt].v = ufs[t[i].v];
    }
    
    for(int i = 1; i <= cnt; i++) t[i] = c[i];
    tot = cnt;
}

void redu(int &tot) {
    int cnt = 0;
    clear(tot);
    sort(t + 1, t + 1 + tot);
    for(int i = 1; i <= tot; i++) {
        if(find(t[i].u) != find(t[i].v)) {
            merge(t[i].u, t[i].v);
            c[++cnt] = t[i];
            pos[t[i].id] = cnt;
        } else if(t[i].val == inf) {
            c[++cnt] = t[i];
            pos[t[i].id] = cnt;
        }
    }
    
    for(int i = 1; i <= cnt; i++) t[i] = c[i];
    tot = cnt;
}

void solve(int l, int r, int now, ll nowVal) {
    int tot = sum[now];
    if(l == r) val[q[l].first] = q[l].second;
    for(int i = 1; i <= tot; i++)
        e[now][i].val = val[e[now][i].id];
    for(int i = 1; i <= tot; i++)
        t[i] = e[now][i], pos[e[now][i].id] = i;
    
    if(l == r) {
        ans[l] = nowVal;
        sort(t + 1, t + 1 + tot);
        clear(tot);
        for(int i = 1; i <= tot; i++) {
            int u = find(t[i].u), v = find(t[i].v);
            if(u == v) continue;
            merge(u, v);
            ans[l] += t[i].val;
        }
        return;
    }
    
    for(int i = l; i <= r; i++)
        t[pos[q[i].first]].val = -inf;
    cont(tot, nowVal);
    
    for(int i = l; i <= r; i++)
        t[pos[q[i].first]].val = inf;
    redu(tot);
    
    ++now;
    for(int i = 1; i <= tot; i++)
        e[now][i] = t[i];
    sum[now] = tot;
    
    int mid = (l + r) / 2;
    solve(l, mid, now, nowVal);
    solve(mid + 1, r, now, nowVal);
}

int main() {
//    freopen("1.in", "r", stdin);
    
    read(n), read(m), read(qn);
    for(int i = 1; i <= m; i++) {
        read(e[0][i].u), read(e[0][i].v), read(e[0][i].val);
        e[0][i].id = i;
        val[i] = e[0][i].val;
    }
    for(int i = 1; i <= qn; i++)
        read(q[i].first), read(q[i].second);
    
    sum[0] = m;
    solve(1, qn, 0, 0LL);
    
    for(int i = 1; i <= qn; i++)
        printf("%lld
", ans[i]);

    return 0;
}
View Code
原文地址:https://www.cnblogs.com/CzxingcHen/p/9866505.html