luoguP4208 最小生成树计数
【注】:代码还未添加注释,有时间补
如果还不会生成树计数,看 -> 生成树计数
例题:luoguP4208
题意:
求一个图中有多少个不同的最小生成树。
题解:
两个非常重要的定理: 先假设A,B分别表示图G的两个不同的最小生成树。
- 定理一: 设ai,bi分别为A,B表示的MST中按权值从小到大排序的边,则(w(a_{i})=w(b_{i}))
- 定理二: A,B都从0开始加边直到构成MST,每条边加完后图的联通性相同。
我们先回忆如何去用Kruskal算法求MST,先找最小的边,再依此找大一点的边,每次如果边两边的点不在一个联通块中这条边就是MST中的一条边。根据定理一知图G的所有MST中各权值的边数量相等。举一个例子,这个例子会贯穿整个后面的部分(所有举例仅感性理解,正确性不保证):假设构成A的边为1,2,2,3,3。那么定理一想说明什么呢? 想说明的是B,乃至所有G的MST都是由1, 2, 2, 3, 3这样的边的组合构成的。
再假设G中边为1的有5条,边权为2的有3条,边权为3的有2条。那么我们现在有一个模糊的想法是能不能试着求出图的各权值的边分别有多少种方式去参与构成MST。有点像5个上衣,3个裤子,2个帽子,问你出去玩带1个上衣,2个裤子,2个帽子的方法数。直接(C_{5}^{1}*C_{3}^{2}*C_{2}^{2}=15) 种。但MST计数不能这么做,因为不能保证你5个1种选的这1个边能构成MST。相当于5个上衣,3个裤子,2个帽子,问你出去玩带1个上衣,2个裤子,2个帽子,且带的东西要构成一个叫“MST”的效果的方法数。
根据定理二,当我把某个边处理完后,对于所有的MST联通性都是一样的。(解释一下联通性一样:各图都由相同数量的联通块组成,每个联通块中的节点也都相同)。定理二想说明什么呢?说明如果我在构造所有G的MST时,把除了某一特定边权的边外对其他边权的边进行Kruskal算法,得到的新的图的联通性一样(注意得到的图可能不同)。还可以得出我对所有这些新的图的联通块进行缩点,最后得到的图都是一样的。因为本身各图的差异性体现在联通块种的连边方式可能不同,但缩点把各联通块种的点看成一个点,那这种块内连边不同这一差异性就没有了,所以缩点后图都一样。没错这一点非常重要。
有了这一点我只需要对除了一个边权外其他边跑Kruskal算法,再缩点,就能把所有其他边的可能组合变成一个图。这个图中一个点表示一个联通块,2个点间不连通。对这个图求生成树数量,相当于把所有未联通的联通块联通了,这时再考虑缩点前图的差异性,缩点后的图的生成树,一定能把缩点前所有有差异的图都联通。那就是这个边权的参与构成MST的方案数,且不管其他边怎么组合,在这些方案中选一个都能形成MST。
只要为我们每次对每个边权 w[i] ,把除这个边权的其他边处理完的图进行缩点,再在缩点后的图上计算生成树数量,就是w[i]的参与构成MST的方案数,所有方案数相乘就是图最小生成树数量。
我把代码分为 个步骤:
- 输入部分和初始化
- kruskal算法建立MST
- 用num[]记录各权值的边的数量
- 计算每个权值的边的参与构成MST的方案数,用val[i]表示
- 最后计算结果
定义/输入/初始化
int n, m;
int root[N];
ll val[N];
int vis[N], num[N];
int k[N][N];
map<int, int> mp;
struct Edge{
int u, v, c;
}edge[N];
//输入和初始化部分
memset(vis, 0, sizeof(vis));
scanf("%d%d",&n,&m);
init();
for(int i = 1; i <= m; ++ i){
scanf("%d%d%lld", &edge[i].u, &edge[i].v, &edge[i].c);
}
sort(edge + 1, edge + m + 1, cmp);
kruskal算法建立MST
int tp = 0;
for(int i = 1; i <= m; ++ i){
int u = edge[i].u, v = edge[i].v;
int tu = Find(u), tv = Find(v);
if(tu != tv){
tp ++;
vis[i] = 1;
if(tu > tv) swap(tu, tv);
root[tv] = tu;
}
}
if(tp != n - 1) return 0;
用num[]记录所有权值的边的数量
int temp = edge[1].c;
int numval = 0, tot = 0;
for(int i = 1; i <= m; ++ i){
ll w = edge[i].c;
if(w == temp) numval ++;
else{
num[++ tot] = numval;
temp = w;
numval = 1;
}
}
num[++ tot] = numval;
计算每个权值的边的贡献
int l = 1, r = 0;
edge[m + 1].c = -1;
int numm = tot; tot = 0;
for(int i = 1; i <= numm; ++ i){
r = l + num[i] - 1;
int sum = 0; //sum记录当前区间给初始最小生成树贡献的边数
for(int j = l; j <= r; ++ j){
sum += vis[j];
}
if(l == r || sum <= 0){
l = r + 1;
val[++ tot] = 1;
continue;
}
memset(k, 0, sizeof(k));
solve(l, r);
int cnt = 0;
int flag = 0;
for(int j = l; j <= r; ++ j){
int u = edge[j].u, v = edge[j].v;
int tu = Find(u), tv = Find(v);
if(tu != tv){
flag = 1;
if(!mp[tu]) mp[tu] = ++ cnt;
if(!mp[tv]) mp[tv] = ++ cnt;
k[mp[tu]][mp[tv]] --; k[mp[tv]][mp[tu]] --;
k[mp[tu]][mp[tu]] ++; k[mp[tv]][mp[tv]] ++;
}
}
if(!flag){
l = r + 1;
val[++ tot] = 1;
continue;
}
if(cnt == 2) val[++ tot] = k[2][2];
else val[++ tot] = gauss(cnt);
l = r + 1;
}
最后计算结果
// 最后计算结果
ll res = 1;
for(int i = 1; i <= tot; ++ i){
if(!val[i]) continue;
res = (res * val[i]) % mod;
}
return res;
完整代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
#include<vector>
#include<string>
#include<fstream>
using namespace std;
#define rep(i, a, n) for(int i = a; i <= n; ++ i)
#define per(i, a, n) for(int i = n; i >= a; -- i)
typedef long long ll;
typedef pair<ll, int>PII;
const int N = 1e3 + 5;
const ll mod = 31011;
const double Pi = acos(- 1.0);
const int INF = 0x3f3f3f3f;
const int G = 3, Gi = 332748118;
ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; }
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
//
int n, m;
int root[N];
ll val[N];
int vis[N], num[N];
int k[N][N];
map<int, int> mp;
struct Edge{
int u, v, c;
}edge[N];
bool cmp(Edge a, Edge b){
return a.c < b.c;
}
int Find(int x){
return x == root[x] ? x : Find(root[x]);
}
ll gauss(int n){
if(n == 3) return abs(k[2][2] * k[3][3] - (k[2][3] * k[3][2]));
ll res = 1;
for(int i = 1; i <= n - 1; ++ i){
for(int j = i + 1; j <= n - 1; ++ j){
while(k[j][i]){
ll t = k[i][i] / k[j][i];
for(int z = i; z <= n - 1; ++ z){
k[i][z] = (k[i][z] - t * k[j][z] % mod + mod) % mod;
swap(k[i][z], k[j][z]);
}
res = -res;
}
}
res = (res * k[i][i]) % mod;
}
return (res + mod) % mod;
}
void init(){
for(int i = 0; i <= n; ++ i){
root[i] = i;
}
mp.clear();
}
void solve(int l, int r){
init();
for(int i = 1; i <= m; ++ i){
if((i < l || i > r) && vis[i]){
int u = edge[i].u, v = edge[i].v;
int tu = Find(u), tv = Find(v);
if(tu != tv){
if(tu > tv) swap(tu, tv);
root[tv] = tu;
}
}
}
}
ll work(){
//kruskal建立最小生成树部分,vis1[]标记建立最小生成树的边
int tp = 0;
for(int i = 1; i <= m; ++ i){
int u = edge[i].u, v = edge[i].v;
int tu = Find(u), tv = Find(v);
if(tu != tv){
tp ++;
vis[i] = 1;
if(tu > tv) swap(tu, tv);
root[tv] = tu;
}
}
if(tp != n - 1) return 0;
//用num[]记录所有权值的边的数量
int temp = edge[1].c;
int numval = 0, tot = 0;
for(int i = 1; i <= m; ++ i){
ll w = edge[i].c;
if(w == temp) numval ++;
else{
num[++ tot] = numval;
temp = w;
numval = 1;
}
}
num[++ tot] = numval;
//计算每个权值的边的贡献,val[i]表示权值为i的这些边对答案即最小生成树个数的贡献
int l = 1, r = 0;
edge[m + 1].c = -1;
int numm = tot; tot = 0;
for(int i = 1; i <= numm; ++ i){
r = l + num[i] - 1;
int sum = 0; //sum记录当前区间给初始最小生成树贡献的边数
for(int j = l; j <= r; ++ j){
sum += vis[j];
}
if(l == r || sum <= 0){
l = r + 1;
val[++ tot] = 1;
continue;
}
memset(k, 0, sizeof(k));
solve(l, r);
int cnt = 0;
int flag = 0;
for(int j = l; j <= r; ++ j){
int u = edge[j].u, v = edge[j].v;
int tu = Find(u), tv = Find(v);
if(tu != tv){
flag = 1;
if(!mp[tu]) mp[tu] = ++ cnt;
if(!mp[tv]) mp[tv] = ++ cnt;
k[mp[tu]][mp[tv]] --; k[mp[tv]][mp[tu]] --;
k[mp[tu]][mp[tu]] ++; k[mp[tv]][mp[tv]] ++;
}
}
if(!flag){
l = r + 1;
val[++ tot] = 1;
continue;
}
if(cnt == 2) val[++ tot] = k[2][2];
else val[++ tot] = gauss(cnt);
l = r + 1;
}
// 最后计算结果
ll res = 1;
for(int i = 1; i <= tot; ++ i){
if(!val[i]) continue;
res = (res * val[i]) % mod;
}
return res;
}
int main()
{
//输入和初始化部分
memset(vis, 0, sizeof(vis));
scanf("%d%d",&n,&m);
init();
for(int i = 1; i <= m; ++ i){
scanf("%d%d%lld", &edge[i].u, &edge[i].v, &edge[i].c);
}
sort(edge + 1, edge + m + 1, cmp);
printf("%lld", work());
return 0;
}