HDU 4443 带环树形dp

思路:如果图只是一棵树,这个问题很好解决:先做一次树形dp,再dfs一次把从上方传来的贡献往下推,即可求出答案。本题的图带一个环,考虑到环上的点不是很多,可以先对挂在每个环上点下面的树做dp,再暴力处理出环上的信息,最后对每棵树再dfs一次往下推求答案。细节比较多。

#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PII pair<int, int>
#define PLI pair<LL, int>
#define ull unsigned long long
using namespace std;

// Capacity of every node-indexed array (node ids are 1-based).
constexpr int N = 100007;
// Contest-template constants; none of them is referenced by the code shown here.
constexpr int inf = 0x3f3f3f3f;
constexpr LL INF = 0x3f3f3f3f3f3f3f3f;
constexpr int mod = 1000000007;
constexpr double eps = 1e-8;

// ans[u]  : value deposited by dfs2 at nodes without tree children; main sums
//           the five largest entries and divides by n.
double ans[N];
// dp[u]   : tree-DP value computed by dfs1 (1/(k+1) + sum of dp[child]/k).
double dp[N];
// up[c]   : for a cycle node c, the weight main sends around the cycle
//           (half in each direction).
double up[N];
// down[c] : for a cycle node c, the weight accumulated from the cycle walk in
//           main and handed to dfs2 as the root's incoming value.
double down[N];

// n: node count of the current case (also its edge count); m: unused here;
// tot: next free slot in edge[].
int n, m, tot;
// head[u]: index of u's first adjacency edge, -1 when empty.
int head[N];
// deg[u]: remaining degree while leaves are being peeled in main.
int deg[N];
// edgecnt[u]: number of tree children of u, filled by dfs1.
int edgecnt[N];

// is[u]: u was peeled off as a leaf, i.e. it is not on the cycle.
bool is[N];
// vis[u]: cycle node u has already been collected into cir.
bool vis[N];
// Cycle nodes in the order dfs collected them.
vector<int> cir;

// One directed half of an undirected edge; adjacency lists are linked via nx.
struct Edge {
    int from;
    int to;
    int nx;  // next edge leaving `from`, or -1
};
// Every undirected edge is stored twice, hence twice the node capacity.
Edge edge[2 * N];

void addEdge(int u, int v) {
    edge[tot].from = u;
    edge[tot].to = v;
    edge[tot].nx = head[u];
    head[u] = tot++;
}
void dfs(int u) {
    vis[u] = true;
    cir.push_back(u);
    for(int i = head[u]; ~i; i = edge[i].nx) {
        int v = edge[i].to;
        if(is[v] || vis[v]) continue;
        dfs(v);
    }
}

void dfs1(int u, int fa) {
    edgecnt[u] = 0;
    for(int i = head[u]; ~i; i = edge[i].nx) {
        if(edge[i].to == fa || !is[edge[i].to]) continue;
        int v = edge[i].to;
        edgecnt[u]++;
        dfs1(v, u);
    }
    dp[u] = 1.0 / (edgecnt[u]+1);
    if(edgecnt[u]) {
        for(int i = head[u]; ~i; i = edge[i].nx) {
            if(edge[i].to == fa || !is[edge[i].to]) continue;
            int v = edge[i].to;
            dp[u] += dp[v] / edgecnt[u];
        }
    }
}

// Top-down pass over the tree of peeled (`is`) nodes below u. It pushes the
// incoming weight `val` towards the children and deposits it into ans[] at
// nodes that have no tree children (edgecnt[u] == 0).
//   u   - current node.
//   fa  - parent of u; 0 marks u as the cycle node this tree hangs from
//         (node ids are 1-based, so 0 is never a real node).
//   val - weight arriving at u from above: down[u] for the cycle root,
//         otherwise the value computed by the parent for this child.
// Requires edgecnt[] and dp[] to have been filled by dfs1 for this tree.
void dfs2(int u, int fa, double val) {
    // No tree children: everything reaching u stays at u.
    if(!edgecnt[u]) {
        ans[u] += val;
        return;
    }
    // sum = total dp[] over u's tree children.
    double sum = 0;
    for(int i = head[u]; ~i; i = edge[i].nx) {
        if(edge[i].to == fa || !is[edge[i].to]) continue;
        sum += dp[edge[i].to];
    }
    if(!fa) {
        // Cycle root. The +2 / +1 denominators match the 2/(edgecnt+2) and
        // 2/(edgecnt+1) terms main uses for up[]; presumably they account for
        // the root's two cycle neighbours in addition to its tree children.
        for(int i = head[u]; ~i; i = edge[i].nx) {
            if(edge[i].to == fa || !is[edge[i].to]) continue;
            int v = edge[i].to;
            // With a single child, sum - dp[v] == 0 and val/1 == val, so this
            // branch equals the general formula below for edgecnt[u] == 1.
            if(edgecnt[u] == 1) dfs2(v, u, 1.0/(edgecnt[u]+2)+val);
            else dfs2(v, u, 1.0/(edgecnt[u]+2)+val/edgecnt[u]+(sum-dp[v])/(edgecnt[u]+1));
        }
    } else {
        // Inner tree node: edgecnt[u] children plus one parent.
        // Child v receives u's own 1/(k+1) share, an equal 1/k split of the
        // weight from above, and its siblings' dp mass split k ways.
        for(int i = head[u]; ~i; i = edge[i].nx) {
            if(edge[i].to == fa || !is[edge[i].to]) continue;
            int v = edge[i].to;
            dfs2(v, u, 1.0/(edgecnt[u]+1)+val/edgecnt[u]+(sum-dp[v])/edgecnt[u]);
        }
    }
}

void init() {
    tot = 0; cir.clear();
    for(int i = 1; i <= n; i++)
        head[i]=-1, ans[i]=deg[i]=is[i]=vis[i]=0;
}

int main() {
    while(scanf("%d", &n) != EOF && n) {
        init();
        for(int i = 1; i <= n; i++) {
            int u, v; scanf("%d%d", &u, &v);
            addEdge(u, v); addEdge(v, u);
            deg[u]++, deg[v]++;
        }
        queue<int> que;
        for(int i = 1; i <= n; i++) {
            if(deg[i] == 1) {
                que.push(i);
                is[i] = true;
            }
        }
        while(!que.empty()) {
            int u = que.front(); que.pop();
            for(int i = head[u]; ~i; i = edge[i].nx) {
                int v = edge[i].to;
                if(is[v]) continue;
                deg[v]--, deg[u]--;
                if(deg[v] == 1) {
                    is[v] = true;
                    que.push(v);
                }
            }
        }

        for(int i = 1; i <= n; i++)
            if(!is[i] && !vis[i]) dfs(i);

        int cnt = cir.size();
        for(int i = 0; i < cnt; i++) {
            int root = cir[i];
            dfs1(root, 0);
            up[root] = 1.0*2/(edgecnt[root]+2); down[root] = 0;
            for(int j = head[root]; ~j; j = edge[j].nx) {
                int v = edge[j].to;
                if(!is[v]) continue;
                up[root] += dp[v]*2/(edgecnt[root]+1);
            }
        }

        for(int i = 0; i < cnt; i++) {
            double now = up[cir[i]]/2;
            for(int k = 1; k < cnt; k++) {
                int j = (i+k)%cnt;
                if(k == cnt-1) down[cir[j]] += now;
                else down[cir[j]] += now*(edgecnt[cir[j]])/(edgecnt[cir[j]]+1);
                now /= edgecnt[cir[j]]+1;
            }
            now = up[cir[i]]/2;
            for(int k = 1; k < cnt; k++) {
                int j = (i-k+cnt)%cnt;
                if(k == cnt-1) down[cir[j]] += now;
                else down[cir[j]] += now*(edgecnt[cir[j]])/(edgecnt[cir[j]]+1);
                now /= edgecnt[cir[j]]+1;
            }
        }
        for(int i = 0; i < cnt; i++)
            dfs2(cir[i], 0, down[cir[i]]);


        sort(ans+1, ans+1+n);
        reverse(ans+1, ans+1+n);
        double ret = 0;
        for(int i = 1; i <= 5; i++)
            ret += ans[i];
        printf("%.5f
", ret/n);
    }
    return 0;
}

/*
*/
原文地址:https://www.cnblogs.com/CJLHY/p/9855636.html