[JSOI2008]最小生成树计数

1016: [JSOI2008]最小生成树计数

Time Limit: 1 Sec  Memory Limit: 162 MB
[Submit][Status][Discuss]

Description

  现在给出了一个简单无向加权图。你不满足于求出这个图的最小生成树,而希望知道这个图中有多少个不同的最小生成树。(如果两颗最小生成树中至少有一条边不同,则这两个最小生成树就是不同的)。由于不同的最小生成树可能很多,所以你只需要输出方案数对31011的模就可以了。

Input

  第一行包含两个数,n和m,其中1<=n<=100; 1<=m<=1000; 表示该无向图的节点数和边数。每个节点用1~n的整数编号。接下来的m行,每行包含两个整数:a, b, c,表示节点a, b之间的边的权值为c,其中1<=c<=1,000,000,000。数据保证不会出现自回边和重边。注意:具有相同权值的边不会超过10条。

Output

  输出不同的最小生成树有多少个。你只需要输出数量对31011的模就可以了。

Sample Input

4 6
1 2 1
1 3 1
1 4 1
2 3 2
2 4 1
3 4 1

Sample Output

8
 
最小生成树不同的边权用几个是确定的
不存在 用1和8替换掉5和3的情况,因为用5的那个地方用1更优
所以先用kruskal做一遍最小生成树,统计不同边权用了几个
由于相同边权的边不超过10条,
然后对每个边权的每条边暴力dfs用还是不用
所有边权的dfs结果累乘
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mod 31011
using namespace std;
int num;
int fa[101];
struct node
{
    int u,v,w;
}e[1001];
struct edge
{
    int sum,val,l,r;
}g[1001];
bool cmp(node p,node q)
{
    return p.w<q.w;
}
int find(int i) { return fa[i]==i ? i : find(fa[i]); }
void dfs(int now,int tot,int i)
{
    if(now==g[i].r+1)
    {
        if(tot==g[i].sum) num++;
        return;
    }
    dfs(now+1,tot,i);
    int p=find(e[now].u),q=find(e[now].v);
    if(p!=q)
    {
        fa[p]=q;
        dfs(now+1,tot+1,i);
        fa[p]=p; 
    }
}
int main()
{
    int n,m,u,v,tot=0,cnt=0;
    int ans=1;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++) scanf("%d%d%d",&e[i].u,&e[i].v,&e[i].w);
    for(int i=1;i<=n;i++) fa[i]=i;
    sort(e+1,e+m+1,cmp);
    int tmp;
    for(int i=1;i<=m;i++)
    {
        if(e[i].w!=g[cnt].val) g[cnt].r=i-1,g[++cnt].l=i,g[cnt].val=e[i].w;
        u=find(e[i].u); v=find(e[i].v);
        if(u!=v) fa[u]=fa[v],tot++,g[cnt].sum++;
        if(tot==n-1)  { tmp=i; break; }
    }
    if(tot!=n-1) { printf("0
"); return 0; }
    for(int i=tmp+1;i<=m;i++)
        if(e[i].w==g[cnt].val) tmp=i;
    g[cnt].r=tmp;
    for(int i=1;i<=n;i++) fa[i]=i;
    for(int i=1;i<=cnt;i++)
    {
        if(!g[i].sum) continue;
        num=0;
        dfs(g[i].l,0,i);
        ans=ans*num%mod;
        for(int j=g[i].l;j<=g[i].r;j++) 
        {
            u=find(e[j].u); v=find(e[j].v);
            if(u!=v) fa[u]=v;
        }
    }
    printf("%d",ans);
}
View Code

矩阵树定理版

#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int n,m,mod=31011;
int fa[101],ka[101];
struct node
{
    int u,v,w;
}e[1001];
int a[101][101];
bool vis[101];
vector<int>g[101];
long long ans,C[101][101],t;
bool cmp(node p,node q)
{
    return p.w<q.w;
}
int find(int i,int *f) { return f[i]==i ? i : find(f[i],f); }
void init()
{
    int u,v;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++) scanf("%d%d%d",&e[i].u,&e[i].v,&e[i].w);
}
long long det(int h)
{
    long long s=1;
    for(int i=0;i<h;i++)
    {
        for(int j=i+1;j<h;j++)
         while(C[j][i])
         {
             t=C[i][i]/C[j][i];
             for(int k=i;k<h;k++) C[i][k]=(C[i][k]-C[j][k]*t+mod)%mod;
             for(int k=i;k<h;k++) swap(C[i][k],C[j][k]);
             s=-s;
         }
        s=s*C[i][i]%mod;
        if(!s) return 0;
    }
    return (s+mod)%mod;    
}
void matrix_tree()
{
    int len,u,v;
    for(int i=1;i<=n;i++)
        if(vis[i])
        {
            g[find(i,ka)].push_back(i);
            vis[i]=false;
        }
    for(int i=1;i<=n;i++)
        if(g[i].size()>1)
        {
            memset(C,0,sizeof(C));
            len=g[i].size();
            for(int j=0;j<len;j++)
                for(int k=j+1;k<len;k++)
                {
                    u=g[i][j]; v=g[i][k];
                    if(a[u][v])
                    {
                        C[k][j]=(C[j][k]-=a[u][v]);
                        C[k][k]+=a[u][v]; C[j][j]+=a[u][v];
                    }
                }
            ans=ans*det(g[i].size()-1)%mod;
            for(int j=0;j<len;j++) fa[g[i][j]]=i;
        }
    for(int i=1;i<=n;i++) 
    {
        g[i].clear();
        ka[i]=fa[i]=find(i,fa);
    }
}
void solve()
{
    ans=1;
    int u,v;
    for(int i=1;i<=n;i++) fa[i]=ka[i]=i;
    sort(e+1,e+m+1,cmp);
    for(int i=1;i<=m+1;i++)
    {
        if(e[i].w!=e[i-1].w && i!=1 || i==m+1) matrix_tree();
        u=find(e[i].u,fa); v=find(e[i].v,fa);
        if(u!=v)
        {
            vis[u]=vis[v]=true;
            ka[find(u,ka)]=find(v,ka);
            a[u][v]++; a[v][u]++;
        }
    }
    bool flag=true;
    for(int i=1;i<n && flag;i++) 
    if(fa[i]!=fa[i+1]) flag=false;
    printf("%lld
",flag ? ans%mod : 0);
}
int main()
{
    init();
    solve();
}
View Code
原文地址:https://www.cnblogs.com/TheRoadToTheGold/p/7419232.html