虚树入门

虚树,顾名思义,就是假的树.

在树形dp中有很大的优化作用.

虚树主要针对于树中关键点的询问.我们仅仅对关键点及其lca建一棵树.这样只要保证sigmak在时间复杂度内即可.

以下是建树的模板

q=read();
for(int i=1;i<=q;++i)
{
    num=read();
    for(int j=1;j<=num;++j) b[j]=read(),vis[b[j]]=true;//标记关键点. 
    sort(b+1,b+num+1,cmp);//按照dfn排序 
    stak[top=1]=b[1];//强行加入第一个点. 
    for(int j=2;j<=num;++j)
    {
        int now=b[j];
        int lc=lca(now,stak[top]);
        while(1)
        {
            if(deep[lc]>=deep[stak[top-1]])//如果lca为top,或top-1,或在两者之间 
            {
                if(lc!=stak[top])//不等于top 
                {
                    add2(lc,stak[top]);//先连边 
                    if(lc!=stak[top-1]) stak[top]=lc;//如果在两者之间,去掉top,加入lca 
                    else --top;//否则为top-1,直接去掉top即可. 
                }
                break;
            }
            else {add2(stak[top-1],stak[top]);top--;}//lca在top-1之上,top-1向top连边,去掉top1 
        }
        stak[++top]=now;//最后把now加入栈中. 
    }
    while(--top) add2(stak[top],stak[top+1]);//最后将最右链加入加入虚树 
    dfs(stak[1]);//从最上面的点开始dfs 

这里用栈维护了虚树的最右链,dfs中记得将虚树的信息清空即可.

我觉得最难得不是虚树的建立,毕竟这就是一个模板,而是建立虚树后的dp转移...头大...

[SDOI2011]消耗战

这个题要求所有的关键点都不能到达1号点的最小代价.

看到sigma(ki)<=500000,就知道要用到虚树(要养成好习惯).

我们先考虑从普通的dp入手,再探索虚树上应该如何dp.

我们设f[i]表示以i为根的子树内的关键点都不与1联通的最小代价.

考虑当前x的状态如何转移.

首先如果x是关键点,那f[x]只能等于v(fa[x],x).也就是必须切断x的父亲与x的联系。这样x及其子树都不可能与1联通.

倘若x不是关键点,那f[x]=min(sum[x],v(fa[x],x)).sum[x]=sigmaf[y].(y=x.son)

好了,这样普通的dp就只能达到这种地步了.

如果我们把这种dp放到虚树上会是什么样呢?由于我们将许多没用的点都抽离出去了,所以如果一个点是关键带你的话,我们无法做到查询

v(fa[x],x)的值.那我们思考当想要将x的关键点拦截的话,付出他的最小代价究竟是什么,是点x到1的最小的边权.

那我们在之前的dfs中预处理出来这个东西.之后按照上面的转移即可.

#include<bits/stdc++.h>
#define ll long long
#define min(a,b) a<b?a:b
using namespace std;
const int N=500500;
int link1[N],tot1,link2[N],tot2,n,deep[N],f[N][25],q;
int b[N],num,dfn[N],stak[N],top;
ll minv[N];
bool vis[N];
struct edge{int y,next;ll v;}a1[N<<1],a2[N<<1]; 
inline int read()
{
    int x=0,ff=1;
    char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();}
    while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*ff;
}
inline void add1(int x,int y,int v)
{
    a1[++tot1].y=y;
    a1[tot1].v=v;
    a1[tot1].next=link1[x];
    link1[x]=tot1;
}
inline void add2(int x,int y)
{
    a2[++tot2].y=y;
    a2[tot2].next=link2[x];
    link2[x]=tot2;
}
inline void dfs1(int x,int fa)
{
    dfn[x]=++num;
    for(int i=link1[x];i;i=a1[i].next)
    {
        int y=a1[i].y;
        if(y==fa) continue;
        deep[y]=deep[x]+1;
        f[y][0]=x;
        for(int j=1;j<=20;++j) f[y][j]=f[f[y][j-1]][j-1];
        minv[y]=min(minv[x],a1[i].v);
        dfs1(y,x);
    }
}
inline int lca(int a,int b)
{
    if(deep[a]>deep[b]) swap(a,b);
    for(int i=20;i>=0;--i) 
        if(deep[f[b][i]]>=deep[a]) b=f[b][i];
    if(a==b) return a;
    for(int i=20;i>=0;--i)
        if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i];
    return f[a][0];        
}
inline bool cmp(int a,int b) {return dfn[a]<dfn[b];}
inline ll dfs2(int x)
{
    ll sum=0,dp;
    for(int i=link2[x];i;i=a2[i].next)
    {
        int y=a2[i].y;
        sum+=dfs2(y);
    }
    if(vis[x]) dp=minv[x];
    else dp=min(minv[x],sum);
    if(vis[x]) vis[x]=false;
    link2[x]=0;
    return dp;
}
int main()
{
//    freopen("1.in","r",stdin);
    n=read();
    for(int i=1;i<n;++i)
    {
        int x=read(),y=read(),v=read();
        add1(x,y,v);add1(y,x,v);
    }
    minv[1]=1e18;
    dfs1(1,0);q=read();
    while(q--)
    {
        num=read();
        for(int i=1;i<=num;++i)
        {
            b[i]=read();
            vis[b[i]]=true;
        }
        sort(b+1,b+num+1,cmp);
        stak[top=1]=b[1];
        for(int i=2;i<=num;++i)
        {
            int now=b[i];
            int lc=lca(now,stak[top]);
            while(1)
            {
                if(deep[lc]>=deep[stak[top-1]])
                {
                    if(lc!=stak[top]) 
                    {
                        add2(lc,stak[top]);
                        if(lc!=stak[top-1]) stak[top]=lc;
                        else top--;
                    }
                    break;
                }
                else {add2(stak[top-1],stak[top]);top--;}
            }
            stak[++top]=now;
        }
        while(--top) add2(stak[top],stak[top+1]);
        cout<<dfs2(stak[1])<<endl;
        tot2=0;
    }
    return 0;
}
View Code

[HEOI2014]大工程

这种题真的一搞一上午啊,还是我太菜了.....

我们看到k的范围自然就想到了虚树.

那就让我们先考虑普通的dp:

第一问,是所有关键点两两匹配的总长度之和.二三问分别是最长和最小长度.

第一问直接统计每条边的贡献,第二三问用求直径的思想。

我们设sum[x],mx[x],mn[x],size[x]分别表示以x为根的树中,所有关键点到x的路径和,最大值,最小值,和个数.

对于ans1,我们考虑当前处理到y这个儿子.

ans1+=sum[x]*size[y]+(sum[y]+dis(x,y)*size[y])*size[x].这个意思就是之前的子树中每条边都出来与y中的子树匹配.

mx,与mn就不加述说了.

我之前一直在思考如果是关键点的话,怎么特殊处理.因为我们的做法其实枚举了每一个lca,将两端拼接起来的.

可是观察上面的转移,如果我们将关键点的size[x]初始化为1,那size[x]里就为累计一下(sum[y]+dis(x,y)的代价,其实就等同于x与所有关键点的匹配.

在普通树里,dis(x,y)是1,而在虚树里dis(x,y)是deep[y]-deep[x]。之后将其转移即可.

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1000010;
int n,q,link1[N],tot1,link2[N],tot2,deep[N],f[N][25],b[N],num,dfn[N];
int stak[N],top;
ll ans1,ans2,ans3,sum[N],mx[N],mn[N],size[N];
bool vis[N];
struct edge{int y,next;}a1[N<<1],a2[N<<1]; 
inline int read()
{
    int x=0,ff=1;
    char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();}
    while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*ff;
}
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
inline void add1(int x,int y)
{
    a1[++tot1].y=y;
    a1[tot1].next=link1[x];
    link1[x]=tot1;
}
inline void add2(int x,int y)
{
    a2[++tot2].y=y;
    a2[tot2].next=link2[x];
    link2[x]=tot2;
}
inline void dfs1(int x)
{
    dfn[x]=++num;
    for(int i=link1[x];i;i=a1[i].next)
    {
        int y=a1[i].y;
        if(y==f[x][0]) continue;
        deep[y]=deep[x]+1;
        f[y][0]=x;
        for(int j=1;j<=20;++j) f[y][j]=f[f[y][j-1]][j-1];
        dfs1(y);
    }
}
inline int lca(int a,int b)
{
    if(deep[a]>=deep[b]) swap(a,b);
    for(int i=20;i>=0;--i)
        if(deep[f[b][i]]>=deep[a]) b=f[b][i];
    if(a==b) return a;
    for(int i=20;i>=0;--i)
        if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i];
    return f[a][0];        
}
inline void dfs2(int x)
{
    sum[x]=0;mx[x]=0;mn[x]=(vis[x]?0:1e18);size[x]=(vis[x]?1:0);
    for(int i=link2[x];i;i=a2[i].next)
    {
        int y=a2[i].y;
        dfs2(y);
        ll dis=deep[y]-deep[x];
        ans1+=(sum[y]+dis*size[y])*size[x]+sum[x]*size[y];
        ans2=max(ans2,mx[x]+mx[y]+dis);
        ans3=min(ans3,mn[x]+mn[y]+dis);
        sum[x]+=sum[y]+dis*size[y];
        mx[x]=max(mx[x],mx[y]+dis);
        mn[x]=min(mn[x],mn[y]+dis);
        size[x]+=size[y];
    }
    if(vis[x]) vis[x]=false;
    link2[x]=0;
}
int main()
{
    freopen("1.in","r",stdin);
    n=read();
    for(int i=1;i<n;++i)
    {
        int x=read(),y=read();
        add1(x,y);add1(y,x);
    }
    deep[1]=1;dfs1(1);
    q=read();
    for(int i=1;i<=q;++i)
    {
        num=read();
        for(int j=1;j<=num;++j) b[j]=read(),vis[b[j]]=true;
        sort(b+1,b+num+1,cmp);
        stak[top=1]=b[1];
        for(int j=2;j<=num;++j)
        {
            int now=b[j];
            int lc=lca(now,stak[top]);
            while(1)
            {
                if(deep[lc]>=deep[stak[top-1]])
                {
                    if(lc!=stak[top])
                    {
                        add2(lc,stak[top]);
                        if(lc!=stak[top-1]) stak[top]=lc;
                        else --top;
                    }
                    break;
                }
                else {add2(stak[top-1],stak[top]);top--;}
            }
            stak[++top]=now;
        }
        while(--top) add2(stak[top],stak[top+1]);
        ans1=0;ans2=0;ans3=1e18;
        dfs2(stak[1]);
        printf("%lld %lld %lld
",ans1,ans3,ans2);
        tot2=0; 
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/gcfer/p/12491427.html