1036 商务旅行

1036 商务旅行

 

 时间限制: 1 s
 空间限制: 128000 KB
 题目等级 : 钻石 Diamond
 
 
题目描述 Description

某首都城市的商人要经常到各城镇去做生意,他们按自己的路线去做,目的是为了更好的节约时间。

假设有N个城镇,首都编号为1,商人从首都出发,其他各城镇之间都有道路连接,任意两个城镇之间如果有直连道路,在他们之间行驶需要花费单位时间。该国公路网络发达,从首都出发能到达任意一个城镇,并且公路网络不会存在环。

你的任务是帮助该商人计算一下他的最短旅行时间。

输入描述 Input Description

输入文件中的第一行有一个整数N,1<=n<=30 000,为城镇的数目。下面N-1行,每行由两个整数a 和b (1<=ab<=n; a<>b)组成,表示城镇a和城镇b有公路连接。在第N+1行为一个整数M,下面的M行,每行有该商人需要顺次经过的各城镇编号。

输出描述 Output Description

    在输出文件中输出该商人旅行的最短时间。

样例输入 Sample Input
5
1 2
1 5
3 5
4 5
4
1
3
2
5
样例输出 Sample Output

7

数据范围及提示 Data Size & Hint
 

分类标签 Tags 点此展开 

说白了就是:mst+lca+bfs(捎带个并查集)

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

枯燥无味的直接发代码的话,我自己都看不下去,所以我决定讲讲做法,实在做不出来的,再抄我下面的代码吧;

首先我们应该先知道怎么解,路线如下:

1-->3-->2-->5    而n可以达到30000,如果每个点都用spfa求最短距离然后再累加的话,肯定超时

因为图中没有存在环,所以肯定两点之间只存在一条路线,用lca做就妥妥过了(如果可以也能用线段树过,会的来教一下我

也就是说,只要求出1->1(求1->1是怕数据有开始第一个点不是到达1的),1->3,3->2,2->5的最近公共祖先,然后求每对顶点都最近公共祖先的距离和即可算出;

如果听不大懂没关系,我下面稍微模拟一下样例吧

 

    1

      /      

       2          5

                    /    

                     3      4

 

图在上面

1.先用bfs算一次顶点1到各个顶点的距离,用dis数组表示:

        1  2  3  4  5

dis    0  1  2  2  1

  求这个距离的用处,例如:2和5的最近公共祖先为1,而2到5的距离就是dis[2]-dis[1]+dis[5]-dis[1]=(dis[2]+dis[5])-2*dis[1]=2,就是2到公共祖先的距离加上5到公共祖先的距离

 

2.然后lca的做法自己百度吧,算出1--1,1--3,3--2,2--5的公共祖先,刚好都是1(,,,,)

然后算出(dis[1]+dis[1])-2*dis[1]+(dis[1]+dis[3])-2*dis[1]+(dis[3]+dis[2])-2*dis[1]+(dis[2]+dis[5])-2*dis[1] 就是答案了

代码 

#include<cstdio>
#include<iostream>
#include<vector>
#include<queue>
using namespace std;
#define N 101000
vector<int>a[N];
vector<int>t[N];
queue<int>que;
int dis[N],fa[N],n,m;
bool vis[N]={0};
struct node{
    int u,v,w;
}b[N];
int find(int x){
    return fa[x]==x?x:fa[x]=find(fa[x]);
}
void bfs(){
    que.push(1);
    vis[1]=true;
    dis[1]=0;
    int from,to,len;
    while(!que.empty()){
        from=que.front();
        que.pop();
        len=a[from].size();
        for(int i=0;i<len;i++){
            to=a[from][i];
            if(!vis[to]){
                dis[to]=dis[from]+1;
                vis[to]=true;
                que.push(to);
            }
        }
    }
    
}
void lca(int x,int father){
    int len=a[x].size(),to,num;
    for(int i=0;i<len;i++){
        to=a[x][i];
        if(to==father) continue;
        lca(to,x);
        fa[to]=x;
        vis[to]=true;
    }
    len=t[x].size();int xx,yy;
    for(int i=0;i<len;i++){
        num=t[x][i];
        xx=b[num].u;
        yy=b[num].v;
        if(xx==yy){
            b[num].w=xx;
        }
        if(xx!=x&&yy==x&&vis[xx]){
            b[num].w=find(xx);
        }
        if(xx==x&&yy!=x&&vis[yy]){
            b[num].w=find(yy);
        }
    }
}
int clac(){
    int x,y,z,sum=0;
    for(int i=1;i<=m;i++){
        x=b[i].u;
        y=b[i].v;
        z=b[i].w;
        sum+=dis[x]+dis[y]-2*dis[z];
    }
    return sum;
}
int main(){
    scanf("%d",&n);
    for(int i=1,x,y;i<n;i++){
        fa[i]=i;
        scanf("%d%d",&x,&y);
        a[x].push_back(y);
        a[y].push_back(x);
    }
    fa[n]=n;
    scanf("%d",&m);
    b[1].u=1;
    for(int i=1,x;i<=m;i++){
        scanf("%d",&x);
        b[i].v=x;
        t[x].push_back(i);
        if(i+1<=m){
            b[i+1].u=x;
            t[x].push_back(i+1);
        }
    }
    bfs();
    lca(1,1);
    printf("%d
",clac());
    return 0;
}

 倍增版lca

#include<cstdio>
#include<vector>
using namespace std;
#define N 30010
int n,deep[N],g[N][25];
vector<int>p[N];
void dfs(int x,int de){
    for(int i=0;i<p[x].size();i++){
        if(!deep[p[x][i]]){
            deep[p[x][i]]=deep[x]+1;
            g[p[x][i]][0]=x;
            dfs(p[x][i],de+1);
        }
    }
}
int lca(int a,int b){
    if(deep[a]<deep[b]) swap(a,b);
    int t=deep[a]-deep[b];
    for(int i=0;i<=20;i++){
        if((1<<i)&t){
            a=g[a][i];
        }
    }
    if(a==b) return a;
    for(int i=20;i>=0;i--){
        if(g[a][i]!=g[b][i]){
            a=g[a][i];
            b=g[b][i];
        }
    }
    return g[a][0];
}
int main(){
    scanf("%d",&n);
    for(int i=1,x,y;i<n;i++){
        scanf("%d%d",&x,&y);
        p[x].push_back(y);
        p[y].push_back(x);
    }
    dfs(1,1);
    for(int j=1;j<=20;j++){
        for(int i=1;i<=n;i++){
            g[i][j]=g[g[i][j-1]][j-1];
        }
    }
    int x,y,q,ans=0;
    scanf("%d%d",&q,&x);
    for(int i=1;i<q;i++){
        scanf("%d",&y);
        int gg=lca(x,y);
        ans+=deep[x]+deep[y]-2*deep[gg];
        x=y;
    }
    printf("%d
",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/shenben/p/5558992.html