// 3312: 小奇挖矿 (problem title: "Xiaoqi Mining")


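/*
    对于一个子节点,距离最近的要么就是父亲,要么就是父亲去的仓库
    Key idea: for a child node, the warehouse it uses is either the same one
    its parent uses, or a separate warehouse opened for the child's own
    subtree. The DP in dfs() below is built on exactly this choice per child.
*/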
#include <cmath>
#include <queue>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
using namespace std;

// Capacity and numeric constants.
// B sizes the per-node / per-edge arrays; inf is the customary 0x3f3f3f3f
// sentinel, small enough that inf + inf still fits in a 32-bit int.
// A and mod are not referenced by the visible solution code.
const int A = 10000010;     // 1e7 + 10
const int B = 1000010;      // 1e6 + 10
const int mod = 1000000007; // 1e9 + 7
const int inf = 0x3f3f3f3f;

/*
 * Fast reader for one (optionally negative) decimal integer from stdin.
 * Skips every non-digit character first; a '-' anywhere in that skipped
 * prefix makes the result negative. Returns 0 if the input ends before any
 * digit is seen.
 *
 * The character is kept in an int rather than a char so EOF stays distinct
 * from real bytes. With a char, a premature end of input made the skip loop
 * spin forever. Also, bytes >= 0x80 became negative values, and passing those
 * to isdigit() is undefined behaviour.
 */
inline int read() {
  int c = getchar();
  int x = 0, f = 1;
  for ( ; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -1;
  for ( ; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
  return x * f;
}

// Adjacency-list edge record: target vertex v, index of the next edge leaving
// the same source (0 terminates the list), and a weight field w that the
// visible code never assigns.
struct node
{
    int v;
    int nxt;
    int w;
};
node e[B<<1];           // room for both directions of every undirected edge

int n;                  // number of nodes
int m;                  // extra cost charged when a subtree opens its own warehouse
int cnt;                // number of edges stored in e[] so far
int head[B];            // head[u] = first edge index out of u, 0 if none
int w[B];               // w[d] = cost for a node whose warehouse is d edges away
int f[4000][4000];      // tree-DP table, filled by dfs()
int vis[B];             // not used by the visible code
int dp[B][23];          // binary lifting: dp[u][k] = 2^k-th ancestor of u (0 = none)
int dep[B];             // depth of each node; the root gets depth 1

int dis[4000][4000];    // dis[i][j] = number of edges between i and j

// Appends the directed edge u -> v to u's adjacency list kept in e[]/head[].
// Edge ids start at 1, so the id 0 can mark the end of a list.
void modify(int u,int v)
{
    const int id=++cnt;
    e[id].v=v;
    e[id].nxt=head[u];
    head[u]=id;
}


// Roots the tree below u (fa = parent, 0 for the root) and builds the
// binary-lifting table.
// Sets dep[u] = dep[fa] + 1. For every k with 2^k < dep[u] it sets
// dp[u][k] = dp[dp[u][k-1]][k-1]; higher entries are left untouched.
// Requires dp[u][0] to already hold u's parent; this function sets dp[v][0]
// for each child v before recursing into it.
void dfs1(int u,int fa)
{
    dep[u]=dep[fa]+1;

    int k=1;
    while ((1<<k)<dep[u])
    {
        const int half=dp[u][k-1];
        dp[u][k]=dp[half][k-1];
        ++k;
    }

    for (int id=head[u];id!=0;id=e[id].nxt)
    {
        const int to=e[id].v;
        if (to!=fa)
        {
            dp[to][0]=u;
            dfs1(to,u);
        }
    }
}

int lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for (int i=20;i>=0;i--)
        if (dep[dp[x][i]]>=dep[y]) 
            x=dp[x][i];
    if (x==y) return x;
    
    for (int i=20;i>=0;i--)
    {
        if (dp[x][i] != dp[y][i])
        {
            x=dp[x][i];
            y=dp[y][i];
        }
    }
    
    return dp[x][0]; 
}

// Tree DP over the subtree of u (fa = parent, 0 for the root).
// f[u][j] is the cost of u's subtree when u is served by the warehouse at
// node j. u itself pays w[dis[u][j]]; the opening cost m of j is not
// included.
// Each child v contributes the cheaper of two options:
//   - sharing j: f[v][j]
//   - opening its own warehouse: min over k of f[v][k] + m
void dfs(int u,int fa)
{
    for (int j=1;j<=n;j++) f[u][j]=w[dis[u][j]];

    for (int id=head[u];id;id=e[id].nxt)
    {
        const int v=e[id].v;
        if (v==fa) continue;
        dfs(v,u);

        // Cheapest way for v's subtree to pay for a warehouse of its own.
        int own=inf;
        for (int k=1;k<=n;k++)
            if (f[v][k]+m<own) own=f[v][k]+m;

        for (int j=1;j<=n;j++)
        {
            const int share=f[v][j];
            f[u][j]+=(share<own?share:own);
        }
    }
}

// Number of edges on the tree path between x and y, from depths and their LCA.
int get(int x,int y)
{
    const int top=lca(x,y);
    return (dep[x]-dep[top])+(dep[y]-dep[top]);
}

// Entry point. Reads the input, builds the tree, precomputes all pairwise
// distances, runs the tree DP and prints the minimum total cost.
// Input read below: n m, then w[1..n-1], then n-1 undirected edges "x y".
// Fixes: main now has the required int return type (implicit int is
// ill-formed C++). The answer is also seeded from a real candidate instead of
// inf, so results above 0x3f3f3f3f are no longer silently capped.
int main()
{
    n=read(),m=read();
    // w[d] = cost for a node whose warehouse is d edges away. w[0] is never
    // read in, so it stays 0 (global zero-initialization).
    for (int i=1;i<n;i++) w[i]=read();
    for (int i=1;i<n;i++)
    {
        int x=read();
        int y=read();
        modify(x,y);
        modify(y,x);
    }

    dfs1(1,0);
    for (int i=1;i<=n;i++)
        for (int j=1;j<=n;j++)
            dis[i][j]=get(i,j);

    dfs(1,0);

    // The root's warehouse must be opened too (pay m); try every location.
    ll ans=(ll)f[1][1]+(ll)m;
    for (int i=2;i<=n;i++) ans=min(ans,(ll)f[1][i]+(ll)m);
    printf("%lld\n",ans);
    return 0;
}

// 原文地址 (original post): https://www.cnblogs.com/lToZvTe/p/14528697.html