【NOIP2018】保卫王国 题解(树形DP+倍增)

题目大意:给定一棵含有$n$个结点的树,每个结点有权值$p_i$。要求驻扎军队,一条边连接的两结点必须至少有一个驻扎军队。现在有$q$次询问,每次规定两个点$a,b$,分别要求它们必须驻扎/不驻扎$(0/1)$。问每次驻扎的最小费用。$n,qleq 10^5$

------------------------

如果没有询问,那就是没有上司的舞会。设$f_{i,0/1}$表示$i$不选/选,以$i$为根的子树所花费的最小代价。有转移:

$f_{i,0}=sum f_{j,1}$

$f_{i,1}=sum min(f_{j,0},f_{j,1})+p_i$

现在若有询问,我们先考虑暴力的做法,就是将它所要求的地方设成$inf/-inf$,然后每次都跑一遍树形DP。这样的复杂度是$O(nq)$的,能得到44pts。然而这样会产生很多冗余状态:发现强制要求$a,b$改变状态只会影响到$a-lca-b$这一条链。所以我们不妨考虑倍增,预处理出$f$,每次只处理$a$到$b$这条链。

设$g_{i,0/1}$表示整棵树去掉以$i$为根的子树,$i$不选/选的最小代价。有转移:

$g_{v,0}=g_{x,1}+f_{x,1}-min(f_{v,0},f_{v,1})$

$g_{v,1}=min(g_{x,0}+f_{x,0}-f_{v,1},g_{v,0})$

令$anc$表示$i$的$2^j$祖先,设$ff_{i,j,0/1,0/1}$表示$anc-i$上路径,$i$不选/选,$anc$不选/选的最小代价。通过枚举$2^{j-1}$祖先的状态进行转移。注意边界处理。

这样我们求出了$f,g,ff$,可以着手对询问的处理了。发现若$a$是$b$的祖先,那么直接倍增上去即可,最后加上$g_{a,x}$;若不为祖先-子孙关系,那么就都先倍增到$lca$的儿子处,然后枚举$lca$和儿子的两个状态取最小值即可。

时间复杂度$O((n+q)log n)$。注意开$long long$

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#define int long long
using namespace std;
const int N=100005;
const int inf=1e18;
char id[10];
int f[N][2],g[N][2],fa[N][21],ff[N][21][2][2],v[N],dep[N],n,m;
int head[N],cnt;
struct node
{
    int next,to;
}edge[N*2];
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if (ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
inline void add(int from,int to)
{
    edge[++cnt]=(node){head[from],to};
    head[from]=cnt;
}
inline void dfs1(int now,int father)
{
    fa[now][0]=father;dep[now]=dep[father]+1;
    f[now][1]=v[now];
    for (int i=head[now];i;i=edge[i].next)
    {
        int to=edge[i].to;
        if (to==father) continue;
        dfs1(to,now);
        f[now][0]+=f[to][1];
        f[now][1]+=min(f[to][0],f[to][1]);
    }
}
inline void dfs2(int now,int fa)
{
    for (int i=head[now];i;i=edge[i].next)
    {
        int to=edge[i].to;
        if (to==fa) continue;
        g[to][0]=g[now][1]+f[now][1]-min(f[to][0],f[to][1]);
        g[to][1]=min(g[to][0],g[now][0]+f[now][0]-f[to][1]);
        dfs2(to,now);
    }
}
inline int solve(int x,int a,int y,int b)
{
    if (dep[x]<dep[y]) swap(x,y),swap(a,b);
    int tx[2]={inf,inf},ty[2]={inf,inf};
    int nx[2],ny[2];
    tx[a]=f[x][a];ty[b]=f[y][b];
    for (int i=19;i>=0;i--)
    {
        if (dep[fa[x][i]]>=dep[y])
        {
            nx[0]=nx[1]=inf;
            for (int j=0;j<2;j++)
                for (int k=0;k<2;k++)
                    nx[j]=min(nx[j],tx[k]+ff[x][i][k][j]);
            tx[0]=nx[0],tx[1]=nx[1],x=fa[x][i];
        }
    }
    if (x==y) return tx[b]+g[x][b];
    for (int i=19;i>=0;i--)
    {
        if (fa[x][i]!=fa[y][i])
        {
            nx[0]=nx[1]=ny[0]=ny[1]=inf;
            for (int j=0;j<2;j++)
                for (int k=0;k<2;k++)
                    nx[j]=min(nx[j],tx[k]+ff[x][i][k][j]),
                    ny[j]=min(ny[j],ty[k]+ff[y][i][k][j]);
            tx[0]=nx[0],tx[1]=nx[1],x=fa[x][i];
            ty[0]=ny[0],ty[1]=ny[1],y=fa[y][i];
        }
    }
    int l=fa[x][0];
    int ans0=f[l][0]-f[x][1]-f[y][1]+tx[1]+ty[1]+g[l][0];
    int ans1=f[l][1]-min(f[x][0],f[x][1])-min(f[y][0],f[y][1])+min(tx[0],tx[1])+min(ty[0],ty[1])+g[l][1];
    return min(ans0,ans1);
}
signed main()
{
    n=read();m=read();scanf("%s",id);
    for (int i=1;i<=n;i++) v[i]=read();
    for (int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    for (int i=1;i<=n;i++)
    {
        ff[i][0][0][0]=inf;
        ff[i][0][0][1]=f[fa[i][0]][1]-min(f[i][0],f[i][1]);
        ff[i][0][1][0]=f[fa[i][0]][0]-f[i][1];
        ff[i][0][1][1]=f[fa[i][0]][1]-min(f[i][0],f[i][1]);
    }
    for (int j=1;j<=19;j++)
        for (int i=1;i<=n;i++)
        {
            int tmp=fa[i][j-1];
            fa[i][j]=fa[tmp][j-1];
            for (int u=0;u<2;u++)
                for (int v=0;v<2;v++)
                {
                    ff[i][j][u][v]=inf;
                    for (int w=0;w<2;w++)
                        ff[i][j][u][v]=min(ff[i][j][u][v],ff[i][j-1][u][w]+ff[tmp][j-1][w][v]);
                }
        }
    while(m--)
    {
        int a=read(),x=read(),b=read(),y=read();
        if (!x&&!y&&(fa[b][0]==a||fa[a][0]==b)){
            printf("-1
");
            continue;
        }
        printf("%lld
",solve(a,x,b,y));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Invictus-Ocean/p/13769275.html