// xxx (placeholder line from the original paste, kept as a comment so the file compiles)

#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<stdlib.h>
//#include<iostream>
using namespace std;

typedef long long LL;      // wide accumulator type for MST totals
int n, m, q;               // vertex, edge and query counts, read in main
const int maxn = 20011;    // vertex capacity (vertices are 1-based)
const int maxm = 100011;   // edge / query capacity
// One directed arc of an adjacency list; insert() stores each undirected edge
// as two such arcs.
struct Edge
{
    int to;     // head vertex of the arc
    int v;      // weight
    int next;   // index of the next arc in the same list (0 ends the list)
    bool imp;   // not read anywhere in the visible code
};
const int inf = 0x3f3f3f3f;
// Adjacency-list graph plus the Kruskal / union-find machinery used by one
// level of the divide and conquer. g[0] is filled from the input in main; the
// other slots are presumably filled by build() (not in this file -- TODO confirm).
struct Graph
{
    // Arc storage. Because `le` starts at 2 and insert() adds two arcs back to
    // back, the k-th inserted edge occupies arcs 2k (x->y) and 2k+1 (y->x).
    Edge edge[maxm<<1];int first[maxn],le,n;
    // Reset to an empty graph on vertices 1..m (arc index 0 terminates lists).
    void clear(int m) {n=m;memset(first,0,sizeof(first));le=2;}
    // Prepend the directed arc x->y with weight v to x's list.
    void in(int x,int y,int v) {Edge &e=edge[le];e.to=y;e.v=v;e.next=first[x];first[x]=le++;}
    // Add an undirected edge as the arc pair (x->y, y->x).
    void insert(int x,int y,int v) {in(x,y,v);in(y,x,v);}

    // One undirected edge as seen by Kruskal; ordered by weight only.
    struct Point
    {
        int from,to,v,id;  // endpoints, weight, even arc index (2k for edge k)
        bool operator < (const Point &b) const {return v<b.v;}
    }p[maxm];int lp;       // p[1..lp] are valid after mst()
    bool ok[maxm];         // ok[i]: p[i] was taken into the spanning forest by the last mst()

    // Disjoint-set union over vertices 1..n with path compression.
    struct Ufs
    {
        int fa[maxn];
        void clear(int n) {for (int i=1;i<=n;i++) fa[i]=i;}
        // Fixed: the original recursed through `ufs[x]`, which is not a member
        // of Ufs; the parent array is `fa`.
        int find(int x) {return x==fa[x]?x:(fa[x]=find(fa[x]));}
        void Union(int x,int y)
        {
            x=find(x),y=find(y);
            if (x==y) return;
            fa[x]=y;
        }
    }ufs,kk;

    // Kruskal over the current arc weights. Fills p[1..lp] with one entry per
    // undirected edge sorted by weight, marks the chosen entries in ok[], and
    // returns the total weight of the chosen spanning forest.
    LL mst()
    {
        lp=0;
        for (int i=1;i<=n;i++)
            for (int j=first[i];j;j=edge[j].next)
            {
                // Fixed: the original listed both arcs of every edge, so lp could
                // reach 2*m and overrun p[maxm] / ok[maxm]. Keeping only the even
                // arc lists each edge once (its reverse arc would always be
                // skipped as a duplicate anyway) and makes id match the x<<1
                // encoding stored in Q.
                if (j&1) continue;
                Point &pt=p[++lp];
                pt.from=i;pt.to=edge[j].to;pt.v=edge[j].v;pt.id=j;
            }
        std::sort(p+1,p+1+lp);
        ufs.clear(n);
        LL ans=0;int tot=1;
        memset(ok,0,sizeof(ok));
        for (int i=1;i<=lp;i++)
        {
            if (ufs.find(p[i].from)==ufs.find(p[i].to)) continue;
            ok[i]=1;
            ufs.Union(p[i].from,p[i].to);
            ans+=p[i].v;tot++;
            if (tot==n) break;  // n-1 edges chosen: the spanning tree is complete
        }
        return ans;
    }
    LL Key;int bel[maxn];  // not used in the visible code
}g[20];

// Per-depth query lists: Q[d][i] is the arc index of the i-th queried edge at
// recursion depth d (main stores x<<1 for depth 0). Zero-initialized.
int Q[20][maxm] = {};
void contraction(int cur,int L,int R)
{
    for (int i=L;i<=R;i++)
        g[cur].edge[Q[cur][i]].v=g[cur].edge[Q[cur][i]^1].v=-inf;
    g[cur].mst();
    g[cur].kk.clear(g.n);
    for (int i=1;i<=g.lp;i++) if (ok[i]) g[cur].kk.Union(g[cur].p[i].from,g[cur].p[i].to);
}

void solve(int L,int R,int cur)
{
    if (L==R)
    {
        printf("%lld
",mst());
        return 0;
    }
    contraction(cur,L,R);
    reduction(cur,L,R);
    build(cur,cur+1);
    const int mid=(L+R)>>1;
    solve(L,mid);
    solve(mid+1,R);
}
    
int main()
{
    scanf("%d%d%d",&n,&m,&q);
    g[0].clear(n);
    for (int i=1,x,y,v;i<=m;i++)
    {
        scanf("%d%d%d",&x,&y,&v);
        g[0].insert(x,y,v);
    }
    for (int i=1,x;i<=q;i++)
    {
        scanf("%d",&x);
        Q[0][i]=x<<1;
    }
    solve(1,q,0);
    return 0;
}
// Original article (原文地址): https://www.cnblogs.com/Blue233333/p/7892555.html