【BZOJ3572】世界树(HNOI2014)-虚树+树形DP

测试地址:世界树
做法:本题需要用到虚树+树形DP。
首先一看这道题我们就知道要用虚树,因此我们先把询问点的虚树先建出来,然后考虑DP。
我们把虚树中每个点受哪个点管辖先求出来,这是通过两次DFS来完成的,一次处理向下方向的最近,一次处理向上方向的最近。然后对于每条虚树上的边,如果边的两端所属的点不同,则表示这条边需要切断,那么我们可以倍增求出断点,在每次切断时求出较下面的那一块的大小即可。在最后别忘了求出根所属块的大小。
然后就是一大堆细节了,比如当距离相同时要比较编号之类……本蒟蒻竟然因为写反一个符号调了4h,太弱了……
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
int n,m,k,a[600010],b[300010];
int first[300010]={0},tot=0,firstv[300010]={0},totv;
int fa[300010][21],dep[300010],order[300010],tim=0;
int st[300010],top;
int down[300010],downp[300010],up[300010],upp[300010],ans[300010],siz[300010];
int belong[300010],dis[300010];
const int inf=1000000000;
bool vis[300010]={0};
struct edge
{
    int v,next,w;
}e[600010],ev[300010];

void insert(int a,int b)
{
    e[++tot].v=b,e[tot].next=first[a],first[a]=tot;
}

void insertv(int a,int b,int w)
{
    ev[++totv].v=b,ev[totv].w=w,ev[totv].next=firstv[a],firstv[a]=totv;
}

void init(int v)
{
    order[v]=++tim;
    siz[v]=1;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa[v][0])
        {
            fa[e[i].v][0]=v;
            dep[e[i].v]=dep[v]+1;
            init(e[i].v);
            siz[v]+=siz[e[i].v];
        }
}

int lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;i--)
        if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
    if (x==y) return x;
    for(int i=20;i>=0;i--)
        if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}

int findfa(int x,int y)
{
    for(int i=20;i>=0;i--)
        if ((1<<i)<=y) x=fa[x][i],y-=(1<<i);
    return x;
}

bool cmp(int a,int b)
{
    return order[a]<order[b];
}

void build()
{
    totv=0;
    sort(a+1,a+k+1,cmp);
    for(int i=1;i<k;i++)
        a[k+i]=lca(a[i],a[i+1]);
    a[k<<1]=1;
    sort(a+1,a+(k<<1)+1,cmp);
    top=0;
    for(int i=1;i<=(k<<1);i++)
        if (i==1||a[top]!=a[i])
        {
            a[++top]=a[i];
            firstv[a[top]]=0;
        }
    k=top;
    top=1;st[1]=1;
    for(int i=2;i<=k;i++)
    {
        while (top>1&&lca(st[top],a[i])!=st[top])
        {
            insertv(st[top-1],st[top],dep[st[top]]-dep[st[top-1]]);
            top--;
        }
        st[++top]=a[i];
    }
    while (top>1)
    {
        insertv(st[top-1],st[top],dep[st[top]]-dep[st[top-1]]);
        top--;
    }
}

void getdown(int v)
{
    down[v]=inf;
    for(int i=firstv[v];i;i=ev[i].next)
    {
        getdown(ev[i].v);
        if (down[v]>down[ev[i].v]+ev[i].w||(down[v]==down[ev[i].v]+ev[i].w&&downp[ev[i].v]<downp[v]))
        {
            downp[v]=downp[ev[i].v];
            down[v]=down[ev[i].v]+ev[i].w;
        }
    }
    if (vis[v]) down[v]=0,downp[v]=v;
}

void getup(int v,int lastw,int f)
{
    if (up[f]<down[f]||(down[f]==up[f]&&upp[f]<downp[f])) up[v]=up[f]+lastw,upp[v]=upp[f];
    else up[v]=down[f]+lastw,upp[v]=downp[f];
    if (vis[v]) up[v]=0,upp[v]=v;
    for(int i=firstv[v];i;i=ev[i].next)
        getup(ev[i].v,ev[i].w,v);
}

int dp(int v)
{
    int remain=siz[v];
    for(int i=firstv[v];i;i=ev[i].next)
    {
        int s=dp(ev[i].v);
        if (belong[ev[i].v]!=belong[v])
        {
            int cutlen=(dis[v]-dis[ev[i].v]+ev[i].w),cut;
            if (cutlen%2==0&&belong[v]<belong[ev[i].v]) cut=findfa(ev[i].v,cutlen/2-1);
            else cut=findfa(ev[i].v,cutlen/2);
            ans[belong[ev[i].v]]=siz[cut]-siz[ev[i].v]+s;
            remain-=siz[cut];
        }
        else remain-=siz[ev[i].v]-s;
    }
    if (v==1) ans[belong[v]]=remain;
    return remain;
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        insert(a,b),insert(b,a);
    }

    fa[1][0]=fa[0][0]=0;
    dep[1]=1,dep[0]=0;
    up[0]=down[0]=inf;
    init(1);
    for(int i=1;i<=20;i++)
        for(int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];

    scanf("%d",&m);
    while(m--)
    {
        int pastk;
        scanf("%d",&k);
        pastk=k;
        for(int i=1;i<=k;i++)
        {
            scanf("%d",&a[i]);
            vis[a[i]]=1;
            ans[a[i]]=0;
            b[i]=a[i];
        }

        build();

        getdown(1);
        getup(1,0,0);
        for(int i=1;i<=k;i++)
        {
            if (up[a[i]]<down[a[i]]||(up[a[i]]==down[a[i]]&&upp[a[i]]<downp[a[i]])) belong[a[i]]=upp[a[i]],dis[a[i]]=up[a[i]];
            else belong[a[i]]=downp[a[i]],dis[a[i]]=down[a[i]];
        }
        dp(1);

        for(int i=1;i<=pastk;i++)
            printf("%d ",ans[b[i]]);
        printf("
");
        for(int i=1;i<=k;i++)
            vis[a[i]]=0;
    }

    return 0;
}
原文地址:https://www.cnblogs.com/Maxwei-wzj/p/9793517.html