CERC2017 Gambling Guide,最短路变形,期望dp

题意
给定一个无向图,你需要从1点出发到达n点,你在每一点的时候,使用1个单位的代价,随机得到相邻点的票,但是你可以选择留在原地,也可以选择使用掉这张票,
问到达n点的最小代价的方案的期望是多少。

分析

dp [i]  : 从I  到 n 需要coin  数量的期望
显然  dp[n]=0。逆序更新  (除了dp[n] ,其他的全初始化为 inf)
如果当前点为u,v为u的相邻点。
v第一次被更新,那么 dp[v]=(deg[v]-1)/deg[v]*dp[v]+1/deg[v]*dp[u]+1(+1是因为又需要一个coin)deg[v]-1 为留在v点的概率,即dp[v]=((deg[v]-1)*dp[v]+dp[u])/deg[v]+1
数学变化后:dp[v]=deg[v]+dp[u]
如果当前点为 P,v为p的相邻点
如果 dp[v]>dp[p] ,那么v再次被更新,假设为第二次更新,那么:
 dp[v]=((deg[v]-2)*dp[v]+dp[u]+dp[p])/deg[v]+1 即  dp[v]=(dp[p]+dp[u]+deg[v])/2
同理第n次更新时:dp[v]=(dp[u]+dp[p]+dp[q]+....+deg[v])/n
Used[v]:用来标记v现在是第几次被更新。可以得到:
                double tmp = dp[v]*used[v];
                used[v]++;
                dp[v] = (tmp+xp)/used[v];
#include <bits/stdc++.h>
#define ll long long 
using namespace std;
const   int N =310000;
double dp[N];
#define P pair<double,int>
const double inf  = 12000000000;
int n,m,x,y;
bool vis[N];
int used[N],deg[N];
struct Node{
    int fr,to,nex;
}e[N*2];
int  head[N],cnt;
void init()
{
    for(int i =0;i<N;i++) 
    {head[i] = -1;
     vis[i]=0;
     used[i]=0;
     deg[i]=0;
    }
     cnt = 0;
}
void add(int u,int  v)
{
    e[cnt].fr=u;e[cnt].to=v;
    e[cnt].nex=head[u];head[u]=cnt++;
}

void  solve()
{
    for(int i =1;i<n;i++) dp[i] = inf;
    
    priority_queue<P,vector<P>,greater<P> >que;//是greater 
    que.push(P(0,n));
    vis[n] =1;
    while(!que.empty()){
        P p =que.top();que.pop();
        int  u = p.second;double  xp =p.first;
        if(dp[u]<xp)  continue;
        for(int i =head[u];i+1;i=e[i].nex){
            Node  nod = e[i];
            int v=nod.to;
            if(!vis[v]){
                vis[v] = 1;
                used[v]=1;
                dp[v] = deg[v]+xp;
                que.push(P(dp[v],v));
            }
            else if(dp[v]>xp){
                double tmp = dp[v]*used[v];
                //xp+=tmp; 这样 xp 在不断变化
                used[v]++;
                //dp[v]=xp/used[v];
                dp[v] = (tmp+xp)/used[v];
                que.push(P(dp[v],v));
            }
        }
    }
    
}
int main()
{
    
    init();
    scanf("%d%d",&n,&m);
    for(int i =0;i<m;i++){
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
        deg[x]++;deg[y]++;
    }
    
    solve();
    printf("%.12f
",dp[1]);
    return 0;
}
原文地址:https://www.cnblogs.com/tingtin/p/10726253.html