BZOJ3991:寻宝游戏 (LCA+dfs序+树链求并+set)

B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物

 

Input

 第一行,两个整数N、M,其中M为宝物的变动次数。

接下来的N-1行,每行三个整数x、y、z,表示村庄x、y之间有一条长度为z的道路。
接下来的M行,每行一个整数t,表示一个宝物变动的操作。若该操作前村庄t内没有宝物,则操作后村庄内有宝物;若该操作前村庄t内有宝物,则操作后村庄内没有宝物。
 
Output

 M行,每行一个整数,其中第i行的整数表示第i次操作之后玩家找到所有宝物需要行走的最短路程。若只有一个村庄内有宝物,或者所有村庄内都没有宝物,则输出0。

 
Sample Input
4 5
1 2 30
2 3 50
2 4 60
2
3
4
2
1

Sample Output

0

100

220

220

280

Hint :1<=N<=100000,1<=M<=100000,对于全部的数据,1<=z<=10^9

思路:把所有有宝藏的地方连起来变成一棵树,结果就是这棵树的权值*2;

我们已知做法:按dfs序排序,结果是所有点的dis之和,去重,需要减去相邻点LCA的dis。

那么容易用set得到前缀和后缀。 这里是假设以1为根,最后减去虚树的根的距离。

倍增LCA版本:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=200010;
int Laxt[maxn],Next[maxn],To[maxn],cost[maxn],dep[maxn];
int fa[maxn][20],in[maxn],pos[maxn],vis[maxn],times,cnt;
ll dis[maxn],ans;
set<int>s;
set<int>::iterator it;
void add(int u,int v,int c){
    Next[++cnt]= Laxt[u];
    Laxt[u]=cnt; To[cnt]=v; cost[cnt]=c;
}
void dfs(int u,int f)
{
    fa[u][0]=f;in[u]=++times;pos[times]=u; dep[u]=dep[f]+1;
    for(int i=Laxt[u];i;i=Next[i]){
        if(To[i]==f) continue;
        dis[To[i]]=dis[u]+cost[i];
        dfs(To[i],u);
    }
}
int LCA(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v); 
    for(int i=19;i>=0;i--) if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
    if(u==v) return u;
    for(int i=19;i>=0;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}
void addset(int x)
{
    s.insert(in[x]);
    it=s.lower_bound(in[x]);
    ans+=dis[x];
    int a=0,b=0;
    if( it!=s.begin() ) a=pos[*(--it)++];
    if((++it)!=s.end() ) b=pos[*it];
    ans+=dis[LCA(a,b) ];
    ans-=dis[LCA(a,x) ] + dis[LCA(b,x) ];
}
void delset(int x)
{
    ans-=dis[x];
    it=s.find(in[x]);
    int a=0,b=0;
    if(it!=s.begin()) a=pos[*(--it)++];
    if((++it)!=s.end()) b=pos[*it];
    ans+=dis[LCA(a,x) ]+dis[LCA(b,x) ];
    ans-=dis[LCA(a,b) ];
    s.erase(in[x]);
}
int main()
{
    int N,M,u,v,c,i,j;
    scanf("%d%d",&N,&M);
    for(i=1;i<N;i++){
        scanf("%d%d%d",&u,&v,&c);
        add(u,v,c); add(v,u,c);
    }
    dfs(1,0);
    for(i=1;i<=19;i++)
     for(j=1;j<=N;j++)
      fa[j][i]=fa[fa[j][i-1]][i-1];
    for(i=1;i<=M;i++){
        scanf("%d",&u);
        if(!vis[u]) vis[u]=1,addset(u);
        else vis[u]=0,delset(u);
        it=s.end();
        int rt=LCA(pos[*s.begin()],pos[*(--it)]);
        //cout<<rt<<" "<<pos[*s.begin()]<<" "<<pos[*(--it)]<<" "<<dis[rt]<<" ";
        printf("%lld
",(ans-dis[rt])*2);    
    }
    return 0;
}

ST表LCA(直接得到距离)版本。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=200010;
int Laxt[maxn],Next[maxn],To[maxn],cost[maxn];
int in[maxn],vis[maxn],times,cnt,lg2[maxn]; 
ll st[maxn][20],dis[maxn],ans;
set<int>s;
set<int>::iterator it,pre,lat;
void add(int u,int v,int c){
    Next[++cnt]= Laxt[u];
    Laxt[u]=cnt; To[cnt]=v; cost[cnt]=c;
}
void dfs(int u,int f)
{
    in[u]=++times; st[times][0]=dis[u]; 
    for(int i=Laxt[u];i;i=Next[i]){
        if(To[i]==f) continue;
        dis[To[i]]=dis[u]+cost[i];
        dfs(To[i],u);
        st[++times][0]=dis[u];
    }
}
ll LCA(int x,int y)
{
    int t=lg2[y-x+1];
    return min(st[x][t],st[y-(1<<t)+1][t]);
}
void addset(int x,ll t)
{
    if(t==1) s.insert(x);    
    it=s.find(x); lat=s.upper_bound(x); 
    if(it!=s.begin()) pre=--s.lower_bound(x),ans-=LCA(*pre,*it)*t;
    if(lat!=s.end())ans-=LCA(*it,*lat)*t;
    if(it!=s.begin()&&lat!=s.end()) ans+=LCA(*pre,*lat)*t;
    if(t==-1) s.erase(x);
}
int main()
{
    int N,M,u,v,c,i,j;
    scanf("%d%d",&N,&M);
    for(i=1;i<N;i++){
        scanf("%d%d%d",&u,&v,&c);
        add(u,v,c); add(v,u,c);
    }
    dfs(1,0);
    for(lg2[1]=0,i=2;i<=times;++i) lg2[i]=lg2[i>>1]+1;
    for(i=1;i<=lg2[times];i++)
     for(j=1;j+(1<<i)-1<=times;j++)
      st[j][i]=min(st[j][i-1],st[j+(1<<i-1)][i-1]);
    for(i=1;i<=M;i++){
    scanf("%d",&u);
        if(!vis[u]) ans+=dis[u],addset(in[u],1);
        else ans-=dis[u],addset(in[u],-1);
        vis[u]^=1;
        ll rt=LCA(*s.begin(),*(--s.end()));
        printf("%lld
",(ans-rt)*2);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/hua-dong/p/9256375.html