bzoj4637:期望

思路:最小生成树计数只不过加了一个期望,由于期望具有线性性质,就可以转化为每条边的期望之和,那么一条边的期望如何求呢,在最小生成树记数中,是把相同边权的一起处理,之后把属于连通块内的点缩点,也就是说,一条边只可能在它属于的连通块内对答案产生贡献,之后因为缩点而不会影响答案,因此一条边的期望就等于它在它所属的连通块内包含它的生成树个数除以那个连通块的生成树个数,而包含这条边的生成树个数就是该连通块内所有的生成树个数减去不包含这条边的生成树个数,然后用matrix-tree定理统计答案即可,因为这题要枚举边,所以最好写两个并查集,反正我之前的dfs写法没法写。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
using namespace std;
#define maxm 200005
#define maxn 10005
const long double eps=1e-9;
 
int n,m,cnt,top;
int pos[maxn],stack[maxn];
bool instack[maxn];
long long tot;
long double ans,K[1000][1000],T[1000][1000];
 
vector<int> v[maxn];
 
struct edge{
    int from,to,dis,val;
    bool operator <(const edge &a)const{return dis<a.dis;}
}e[maxm];
 
inline int read(){
    int x=0;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar());
    for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
    return x;
}
 
struct union_find_set{
    int fa[maxn];
    int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
}u1,u2;
 
long long gauss(){
    int t,n=cnt-1,f=1;long double ans=1;
    for (int i=1;i<n;i++){
        for (t=i;t<=n;t++) if (fabs(K[t][i])>eps) break;if (t>n) return 0;
        if (t!=i){for (int j=1;j<=n;j++) swap(K[i][j],K[t][j]);f=-f;}
        for (int j=i+1;j<=n;j++)
            if (fabs(K[j][i])>eps){
                long double t=K[j][i]/K[i][i];
                for (int k=i;k<=n;k++) K[j][k]-=K[i][k]*t;
            }
    }
    for (int i=1;i<=n;i++) ans=ans*K[i][i];
    return round(ans*f);
}
 
void add(int x,int y,int val){
    K[x][y]-=val,K[y][x]-=val;
    K[x][x]+=val,K[y][y]+=val;
}
 
int main(){
    n=read(),m=read();
    for (int i=1;i<=m;i++) e[i].from=read(),e[i].to=read(),e[i].dis=read(),e[i].val=read();
    for (int i=1;i<=n;i++) u1.fa[i]=u2.fa[i]=i; sort(e+1,e+m+1);
    for (int i=1,l=1;i<=m+1;i++){
        int x=u1.find(e[i].from),y=u1.find(e[i].to);
        if (x!=y){int u=u2.find(x),v=u2.find(y);if (u!=v) u2.fa[u]=v;}
        if (e[i].dis!=e[i+1].dis){
            for (int j=l;j<=i;j++){
                int x=u1.find(e[j].from),y=u1.find(e[j].to);
                if (x==y) continue; int u=u2.find(x);
                if (!instack[u]) stack[++top]=u,instack[u]=1;
            }
            while (top){
                instack[stack[top]]=0,cnt=0;
                for (int j=l;j<=i;j++){
                    int x=u1.find(e[j].from),y=u1.find(e[j].to);
                    if (x==y) continue; int u=u2.find(x);
                    if (u==stack[top]){
                        if (!pos[x]) pos[x]=++cnt;
                        if (!pos[y]) pos[y]=++cnt;
                        add(pos[x],pos[y],1);
                    }
                }
                for (int a=1;a<=cnt;a++)
                    for (int b=1;b<=cnt;b++)
                        T[a][b]=K[a][b];
                tot=gauss();
                for (int a=1;a<=cnt;a++)
                    for (int b=1;b<=cnt;b++)
                        K[a][b]=T[a][b];
                for (int j=l;j<=i;j++){
                    int x=u1.find(e[j].from),y=u1.find(e[j].to);
                    if (x==y) continue; int u=u2.find(x);
                    if (u==stack[top]){
                        for (int a=1;a<=cnt;a++)
                            for (int b=1;b<=cnt;b++)
                                T[a][b]=K[a][b];
                        add(pos[x],pos[y],-1);
                        long long tmp=gauss();
                        for (int a=1;a<=cnt;a++)
                            for (int b=1;b<=cnt;b++)
                                K[a][b]=T[a][b];
                        ans+=1.0*(tot-tmp)/tot*e[j].val;
                    }
                }
                for (int j=l;j<=i;j++){
                    int x=u1.find(e[j].from),y=u1.find(e[j].to);
                    if (x==y) continue;pos[x]=pos[y]=0;
                }
                for (int j=1;j<=cnt;j++)
                    for (int k=1;k<=cnt;k++)
                        K[j][k]=0;
                top--;
            }
            for (int j=l;j<=i;j++){
                int x=u1.find(e[j].from),y=u1.find(e[j].to);
                if (x==y) continue;u1.fa[x]=y;
            }
            l=i+1;
        }
    }
    printf("%.5lf",(double)ans);
    return 0;
}
原文地址:https://www.cnblogs.com/DUXT/p/6024252.html