[BZOJ 3991][SDOI2015]寻宝游戏(dfs序)

题面

小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。

小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物

分析

本质上是在一棵树上取出若干节点,询问把这几个节点访问一遍的距离

可以发现如果我们按照dfs序将节点排序,然后将排序后的相邻节点距离相加,最后再加上序列首尾距离,就能求出答案

如序列为{1,3,4,5},则答案为dist(1,3)+dist(3,4)+dist(4,5)+dist(5,1)

因为我们访问节点一定是像dfs一样访问才能得到最短路径,所以正确性显然

我们用一个set维护这个排序后的节点序列

可以发现,每次新加入一个节点之后只会改变节点的前驱和后继相关的距离,维护一下即可

代码

//https://www.luogu.org/problemnew/show/P3320
#include<iostream>
#include<cstdio>
#include<set>
#include<cmath>
#define INF 0x3f3f3f3f
#define maxn 100005
#define maxlogn 20
using namespace std;
int n,m;
struct edge{
    int from;
    int to;
    int next;
    int len;
}E[maxn<<1];
int head[maxn];
int sz=1;
void add_edge(int u,int v,int w){
    sz++;
    E[sz].from=u;
    E[sz].to=v;
    E[sz].next=head[u];
    E[sz].len=w;
    head[u]=sz;
}
int logn;
int cnt=0;
int dfn[maxn];
int hash_dfn[maxn];//存储dfs序为i的节点编号
int anc[maxn][maxlogn];
int deep[maxn];
long long dist[maxn];
void dfs(int x,int fa){
    dfn[x]=++cnt;
    hash_dfn[dfn[x]]=x;
    deep[x]=deep[fa]+1;
    anc[x][0]=fa;
    for(int i=1;i<=logn;i++) anc[x][i]=anc[anc[x][i-1]][i-1];
    for(int i=head[x];i;i=E[i].next){
        int y=E[i].to;
        if(y!=fa){
            dist[y]=dist[x]+E[i].len;
            dfs(y,x);
        }
    } 
}
int lca(int x,int y){
    if(deep[x]<deep[y]) swap(x,y);
    for(int i=logn;i>=0;i--){
        if(deep[anc[x][i]]>=deep[y]){
            x=anc[x][i];
        }
    } 
    if(x==y) return x;
    for(int i=logn;i>=0;i--){
        if(anc[x][i]!=anc[y][i]){
            x=anc[x][i];
            y=anc[y][i];
        }
    }
    return anc[x][0];
}
long long get_dist(int x,int y){
    return dist[x]+dist[y]-2*dist[lca(x,y)]; 
}

set<int>seq; //set里实际上存的是节点的dfs序值而不是编号
long long sum=0,ans;
int main(){
    int u,v,w,x;
    scanf("%d %d",&n,&m);
    logn=log2(n);
    for(int i=1;i<n;i++){
        scanf("%d %d %d",&u,&v,&w);
        add_edge(u,v,w);
        add_edge(v,u,w);
    }
    dfs(1,0);
    seq.insert(INF);//防止越界
    seq.insert(-INF);
    for(int i=1;i<=m;i++){
        scanf("%d",&x);
        if(!seq.count(dfn[x])){
            int l=*(--seq.lower_bound(dfn[x]));
            int r=*seq.upper_bound(dfn[x]);
            //注意边界判断
            if(l!=-INF) sum+=get_dist(hash_dfn[l],x);
            if(r!=INF) sum+=get_dist(hash_dfn[r],x);
            if(l!=-INF&&r!=INF) sum-=get_dist(hash_dfn[l],hash_dfn[r]);
            seq.insert(dfn[x]);
        }else{
            int l=*(--seq.lower_bound(dfn[x]));
            int r=*seq.upper_bound(dfn[x]);
            if(l!=-INF) sum-=get_dist(hash_dfn[l],x);
            if(r!=INF) sum-=get_dist(hash_dfn[r],x);
            if(l!=-INF&&r!=INF) sum+=get_dist(hash_dfn[l],hash_dfn[r]);
            seq.erase(seq.lower_bound(dfn[x]));
        }
        ans=sum;
        if(seq.size()-2>=2) ans+=get_dist(hash_dfn[*(++seq.lower_bound(-INF))],hash_dfn[*(--seq.lower_bound(INF))]);
        //记得加上首尾距离,注意我们一开始插入了-INF和+INF,所以set的大小>=4时才会有序列首尾
        printf("%lld
",ans);
    }
}

原文地址:https://www.cnblogs.com/birchtree/p/10822406.html