点集直径

MMSet2
给定一棵n个节点的树,点编号为1…n。
Q次询问,每次询问给定一个点集S,令,(f(u)=max_{vin S}dist(u,v))
你需要求出(min_{u=1dots n}f(u))
其中dist(u,v)表示树上路径(u,v)的边数。
输入描述:
第一行一个整数n,接下来n−1行每行两个整数表示树上的一条边。
接下来一行一个整数Q,接着Q行,每行第一个数是|S|,剩下|S|个互不相同的数代表这个集合。
输出描述:
输出Q行,每行一个整数表示答案。
示例1
输入

3
1 2
1 3
1
2 2 3

输出

1

备注:
n≤3×105,|S|≥1,∑|S|≤106
每条边的长度是1,显然答案ans就是S的若干条直径中两条半径差值最小时较大的那条半径,因为如果有u'点到S的其他点的距离均小于刚刚的答案,那么直径就可以减小了,当然,这样的半径也能成为f(u),因为如果f(u)能扩大,则直径就能增大了。
由于边权均为1,假设两半径为(r_i,r_2)(r_1>r_2)如果(r_1,r_2)相差d>=2,那么可以将中心点左移或右移一个点,使d-1,所以,d<=1;
ans=直径/2(向上取整)
再来考虑如何计算S的直径:
设D(tree)=(a,b)表示tree的直径是a,b间的距离,那么(D(treecup x)=max(dist(a,b),dist(a,x),dis(b,x)))
证明:
设有点c使得dist(c,x)大于直径。那么a或b在c->x的路径上,因为如果不在,
则x->c->a(或b)更大,所以a->b->c长度大于直径,矛盾。
伪代码:
1.刚开始有一个点a;
2.加入一个点b,如果dist (a,b)使直径变大,b记为x,
3.重复2(期间a不变)直到加完。
这样我们就计算了所有dist(a,x),但是没计算dist(b,x);
最后再循环一次,计算所有dist((b_i),x),遇到更大的就更新直径。

#include<bits/stdc++.h>
using namespace std;
const int MAXN=3e5+8;
struct E{int y,nt;}e[MAXN<<1];
int head[MAXN],cnt;
void add(int x,int y){
    e[++cnt].nt=head[x];
    e[cnt].y=y;
    head[x]=cnt;
}
int tot[MAXN],deep[MAXN],son[MAXN],fa[MAXN];
int dfs1(int now,int pre,int dep){
    tot[now]=1;
    fa[now]=pre;
    deep[now]=dep;
    int max_son=-1;
    for(int i=head[now];i;i=e[i].nt){
        int to=e[i].y;
        if(to==pre)continue;
        tot[now]+=dfs1(to,now,dep+1);
        if(tot[to]>max_son){
            max_son=tot[to];
            son[now]=to;
        }
    }
    return tot[now];
}
int top[MAXN];
void dfs2(int now,int topfa){
    top[now]=topfa;
    if(!son[now])return;
    dfs2(son[now],topfa);
    for(int i=head[now];i;i=e[i].nt){
        int to=e[i].y;
        if(!top[to])dfs2(to,to);
    }
}
int lca(int x,int y){
    while(top[x]^top[y]){
        if(deep[top[x]]<deep[top[y]])swap(x,y);
        x=fa[top[x]];
    }
    if(deep[x]<deep[y])return x;
    return y;
}
inline int dist(int x,int y){return deep[x]+deep[y]-2*deep[lca(x,y)];}
int n,s[MAXN];
int main() {
    scanf("%d",&n);
    int x,y;
    for(int i=1; i<n; ++i) {
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    dfs1(1,0,1);
    dfs2(1,1);
    int q;
    scanf("%d",&q);
    while(q--) {
        int S;
        scanf("%d",&S);
        int d=-1,p=0;
        for(int i=0;i<S;++i){
            scanf("%d",s+i);
            int tmp=dist(s[0],s[i]);
            if(tmp>d){d=tmp,p=i;}
        }
        for(int i=0;i<S;++i){
            d=max(d,dist(s[p],s[i]));
        }
        printf("%d
",(d+1)/2);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/foursmonth/p/14155900.html