Codeforces 418d Big Problems for Organizers [树形dp][倍增lca]

题意:

给你一棵有n个节点的树,树的边权都是1.

有m次询问,每次询问输出树上所有节点离其较近结点距离的最大值。

思路:

1.首先是按照常规树形dp的思路维护一个子树节点中距离该点的最大值son_dis[i],维护非子树节点中距离该点的最大值fa_dis[i];

2.对于每个节点维护它最大的三个儿子节点的son_dis;

3.维护up[i][j]和down[i][j]数组,这个类似倍增lca里边的fa[i][j],up[i][j]代表的含义是从第j个点向上到它的第2^i个父节点这条链上的点除了该节点所在子树外的距离的最大值。down[i][j]同理,但是维护的是从第2^i父节点到该点的链上除了该节点所在子树外的距离的最大值。在这里尤其注意的是,采取了类似差分的思想。看巨巨代码的时候我想了好一会。到这里预处理完毕。

4.对于给定的两个节点。假设a为深度较深的,b为深度浅的。

对于节点a,a到a的子树中所有的点肯定较近,所以son_dis[a]有可能是答案。a到a和b的中点的那条链上距离的最大值也有可能是答案。

对于b

假设b不是公共祖先,那么son_dis[b]有可能是答案。b到中点的链上的距离的最大值也有可能是答案。

若b是公共祖先,那么只有b到中点的链上的距离的最大值也有可能是答案。

对于最近公共祖先r

r的不包含a和b的子树的dis_son有可能是答案,r的fa_dis[r]有可能是答案。

最终结果是在有可能的答案中找最大值。

代码越改越挫。

#include<bits/stdc++.h>
#define MAXN 100050
#define MAXM 200050
using namespace std;
const int inf=0x3f3f3f3f;
struct st{
    int num,id;
};
bool operator < (const st &a,const st &b){
    return a.num>b.num;
}
multiset<st>my_set[MAXN];
struct edge{
    int id;
    edge *next;
};
int ednum;
edge edges[MAXM];
edge *adj[MAXN];
int dep[MAXN],son_dis[MAXN],fa_dis[MAXN],max_num[MAXN],father[MAXN],max_x[MAXN],rt[25][MAXN],siz[MAXN],up[25][MAXN],down[25][MAXN];
bool vis[MAXN];
inline void addedge(int a,int b){
    edge *tmp=&edges[ednum++];
    tmp->id=b;
    tmp->next=adj[a];
    adj[a]=tmp;
}
void dfs(int pos,int deep){
    dep[pos]=deep;
    siz[pos]=1;
    int mmax=-1;
    for(edge *it=adj[pos];it;it=it->next){
        if(dep[it->id]==0){
            father[it->id]=pos;
            rt[0][it->id]=pos;
            dfs(it->id,deep+1);
            st tmp;
            tmp.id=it->id;
            tmp.num=son_dis[it->id];
            my_set[pos].insert(tmp);
            mmax=max(mmax,son_dis[it->id]);
            son_dis[pos]=max(son_dis[pos],son_dis[it->id]+1);
            siz[pos]+=siz[it->id];
        }
    }
    int num=0;
    for(edge *it=adj[pos];it;it=it->next){
        if(father[it->id]==pos&&son_dis[it->id]==mmax)num++;
    }
    max_num[pos]=num;
    max_x[pos]=mmax+1;
}
void dfs2(int pos){
    fa_dis[pos]=fa_dis[father[pos]]+1;
    if(max_num[father[pos]]>1||son_dis[pos]+1!=max_x[father[pos]]){
        fa_dis[pos]=max(fa_dis[pos],max_x[father[pos]]+1);
        up[0][pos]=max_x[father[pos]]-dep[father[pos]];
        down[0][pos]=max_x[father[pos]]+dep[father[pos]];
    }
    else{
        int maxx=-2;
        for(edge *it=adj[father[pos]];it;it=it->next){
            if(father[it->id]==father[pos]&&(it->id!=pos)){
                maxx=max(maxx,son_dis[it->id]);
            }
        }
        fa_dis[pos]=max(fa_dis[pos],maxx+2);
        if(maxx==-2)maxx=-1;
        maxx++;
        up[0][pos]=maxx-dep[father[pos]];
        down[0][pos]=maxx+dep[father[pos]];
    }
    for(edge *it=adj[pos];it;it=it->next){
        if(father[it->id]==pos)dfs2(it->id);
    }
}
void prelca(int n){
    up[0][0]=down[0][0]=-inf;
    for(int i=1;i<=20;i++){
        for(int j=1;j<=n;j++){
            rt[i][j]=rt[i-1][j]==-1?-1:rt[i-1][rt[i-1][j]];
            up[i][j]=max(up[i-1][j],up[i-1][rt[i-1][j]]);
            down[i][j]=max(down[i-1][j],down[i-1][rt[i-1][j]]);
        }
    }
}
int LCA(int u,int v){//查询u和v的lca
    if(dep[u]<dep[v])swap(u,v);
    for(int i=0;i<21;i++){
        if((dep[u]-dep[v])>>i&1){
            u=rt[i][u];
        }
    }
    if(u==v)return u;
    for(int i=19;i>=0;i--){
        if(rt[i][u]!=rt[i][v]){
            u=rt[i][u];
            v=rt[i][v];
        }
    }
    return rt[0][u];
}
int jump(int &pos,int num,int tmp[][MAXN]){//查询节点pos的第num个父亲
    int rel=-inf;
    for(int i=0;i<21;i++){
        if(num>>i&1){
            rel=max(rel,tmp[i][pos]);
            pos=rt[i][pos];
        }
    }
    return rel;
}
void solve(int a,int b){
    int r=LCA(a,b);
    if(dep[a]<dep[b])swap(a,b);
    int maxa,maxb,maxc,maxd,half,v,w,ar,br;
    maxa=maxb=maxc=maxd=0;
    ar=dep[a]-dep[r];
    br=dep[b]-dep[r];
    v=a;w=b;
    half=min((ar+br)/2,ar-1);
    maxa=max(son_dis[a],jump(v,half,up)+dep[a]);
    maxb=jump(v,ar-half-1,down)-dep[r]+br;
    maxc=-inf;
    maxd=fa_dis[r]+min(ar,br);
    if(r!=b){
        maxb=max(maxb,son_dis[b]);
        maxb=max(maxb,jump(w,br-1,up)+dep[b]);
        set<st>::const_iterator it=my_set[r].begin();
        for(int i=1;i<=min((int)my_set[r].size(),3);i++){
            if(it->id!=v&&it->id!=w){
                maxc=it->num+1+min(dep[a],dep[b])-dep[r];
                break;
            }
            it++;
        }
    }
    else{
        set<st>::const_iterator it=my_set[r].begin();
        for(int i=1;i<=min((int)my_set[r].size(),2);i++){
            if(it->id!=v){
                maxc=it->num+1;
                break;
            }
            it++;
        }
    }
    printf("%d
",max(max(maxa,maxb),max(maxc,maxd)));
}
int main(){
    int n;
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int a,b;
        scanf("%d%d",&a,&b);
        addedge(a,b);
        addedge(b,a);
    }
    int m;
    memset(rt,-1,sizeof(rt));
    dfs(1,1);
    for(edge *it=adj[1];it;it=it->next){
        if(father[it->id]==1){
            dfs2(it->id);
        }
    }
    prelca(n);
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        int a,b;
        scanf("%d%d",&a,&b);
        solve(a,b);
    }
}
原文地址:https://www.cnblogs.com/tun117/p/5432705.html