动态DP

应用

动态\(DP\)主要是解决:在树上或链上\(dp\)后,后期对树上链上的点进行修改,然后询问修改后的答案。

其经典例题:

P4719

\(n\)个点的树,给出每个点的点权,求最大权独立集。中途给出\(m\)个修改,每次修改后输出修改后的最优答案。

前置算法

我们主要考虑树上,解决这类问题,需要用到三个算法,树形\(dp\),树链剖分,矩阵乘法。

树形\(dp\)

我们先简单考虑不进行修改,那么这是一道非常简单的树形\(dp\)题,设\(f[x][1]\)为此点必选,\(f[x][0]\)为此点必不选的最大权值。

转移是,设\(v\)\(x\)子节点,\(f[x][1]=\sum f[v][0] +a[x]\)\(f[x][0]=\sum max(f[v][0],f[v][1])\)

\(1\)为根节点。答案就是\(max(f[1][0],f[1][1])\)

树链剖分

往往在树上进行多点修改,就需要用到\(dfs\)序或者树链剖分,前者主要是处理子树,后者是处理链,树链剖分里面分为重链和轻链,因为轻链连接的子树大小是小于等于其父亲的子树的大小的一半,所以轻链只有\(log\)个,那么在上跳的时候,也就只会跳\(log\)次。

矩阵乘法

矩阵乘法是把一个转移式变成矩阵,然后用矩阵乘法来计算,就不用一个一个地转移,而是很快的乘出总的转移式,在修改某处的\(dp\)值后,普通做法是暴力更新一遍,而矩阵乘法可以结合线段树,单点更新后按照线段树操作合并即可,大大优化了复杂度。

动态\(dp\)

讲到动态\(dp\)了,动态\(dp\)的大致思路是 : 把重链的转移与轻链分离,重链上的每个点都有个转移矩阵,转移矩阵的值由与这个点相连的轻链上的值决定,每次修改一个点,就依次沿着重链上跳,每次跳到一个新的重链上时,就根据刚刚跳过来的轻链上的值,修改这个点的转移矩阵,继续上跳。

具体做法:我们设\(g[x][0]\)表示\(x\)点必不选,排除重儿子的情况下,的最大收益,\(g[x][1]\)则表示\(x\)点必选的最大收益。

转移和前面树形\(dp\)类似,我们原来的\(f\)转移就变成了:(设\(v\)为重儿子)

\(f[x][0]=g[x][0]+max(f[v][0],f[v][1]).\) \(f[x][1]=g[x][1]+f[v][0]\)

我们对这个式子变一下:

\(f[x][0]=max(g[x][0]+f[v][0],g[x][0]+f[v][1])\) \(f[x][1]=(-inf+f[v][1],g[x][1]+f[v][0])\)

我们发现,这个和矩阵转移式很像,但是\(+\)法变成了取\(max\)

感性手推了一下,发现也满足结合律......感性理解感性理解

所以我们定义一个新的矩阵乘法,由\(\sum a_{ik}+b_{kj}\)改为\(max( a_{ik}+b_{kj})\)

那么可得:

\[\begin{bmatrix}f_{x1} \\ f_{x0} \end{bmatrix}=\begin{bmatrix} -inf & g_{x1} \\ g_{x0}&g_{x0}\end{bmatrix}\times \begin{bmatrix}f_{v1}\\f_{v0}\end{bmatrix} \]

这个就是转移矩阵了:

\[\begin{bmatrix} -inf & g_{x1}\\g_{x0}&g_{x0}\end{bmatrix} \]

每次修改一个点的值后,相应修改其转移矩阵,然后跳到重链顶端,求出重链顶端的\(f\)值,由这个\(f\)去更新其父亲点的\(g\),然后同时更新其夫妻点的转移矩阵。然后继续上跳即可。

贴一下代码:

#include <bits/stdc++.h>
using namespace std;

int n,m;
const int MAXN=1e5+5;

struct mat{
    int n,m;
    int w[3][3];
};

mat operator * (const mat &aa,const mat &bb){
    mat nw;
    nw.n=aa.n,nw.m=bb.m;
    for(int i=1;i<=nw.n;i++){
        for(int j=1;j<=nw.m;j++){
            nw.w[i][j]=-100000000;
            for(int k=1;k<=aa.m;k++){
                nw.w[i][j]=max(nw.w[i][j],aa.w[i][k]+bb.w[k][j]);
            }
        }
    }
    return nw;
}

int a[MAXN];
int g[MAXN][2];
mat f[MAXN];

int cnt;
int dep[MAXN];
int idx[MAXN];
int fa[MAXN];
int top[MAXN];
int maxn[MAXN];
int son[MAXN];
int siz[MAXN];

mat t[MAXN*4];

vector<int> q[MAXN];

void update(int l,int r,int x,mat v,int id){
    if(l==r){
        t[id]=v;
        return ;
    }
    int mid=(l+r)/2;
    if(mid>=x)update(l,mid,x,v,id*2);
    else update(mid+1,r,x,v,id*2+1);
    t[id]=t[id*2]*t[id*2+1];
}
mat query(int l,int r,int z,int y,int id){
    if(l==z&&r==y)return t[id];
    int mid=(l+r)/2;
    if(mid>=y)return query(l,mid,z,y,id*2);
    else if(mid<z)return query(mid+1,r,z,y,id*2+1);
    else return query(l,mid,z,mid,id*2)*query(mid+1,r,mid+1,y,id*2+1);
}

void dfs_init(int x,int pr){
    fa[x]=pr,dep[x]=dep[pr]+1;
    for(int i=0;i<q[x].size();i++){
        int nx=q[x][i];
        if(nx==pr)continue;
        dfs_init(nx,x);
        siz[x]+=siz[nx];
        if(siz[nx]>=siz[son[x]])son[x]=nx;
        f[x].w[1][1]+=f[nx].w[2][1];
        f[x].w[2][1]+=max(f[nx].w[1][1],f[nx].w[2][1]);
    }
    f[x].w[1][1]+=a[x];
    siz[x]++;
}

mat G(int x){
    mat nw;
    nw.m=nw.n=2;
    nw.w[1][1]=-100000000,nw.w[1][2]=g[x][1];
    nw.w[2][1]=nw.w[2][2]=g[x][0];
    return nw;
}

void calg(int x){
    g[x][0]=g[x][1]=0;
    for(int i=0;i<q[x].size();i++){
        int nx=q[x][i];
        if(nx==fa[x]||nx==son[x])continue;
        g[x][0]+=max(f[nx].w[1][1],f[nx].w[2][1]);
        g[x][1]+=f[nx].w[2][1];
    }
    g[x][1]+=a[x];
}

void dfs_link(int x,int pr){
    cnt++,idx[x]=cnt;
    maxn[top[x]]=x;
    if(son[x]){
        top[son[x]]=top[x];
        dfs_link(son[x],x);
    }
    for(int i=0;i<q[x].size();i++){
        int nx=q[x][i];
        if(nx==pr||nx==son[x])continue;
        top[nx]=nx;
        dfs_link(nx,x);
    }
    calg(x);
    update(1,n,idx[x],G(x),1);
}

void up(int x,int v){
    a[x]=v;
    while(x)
    {
        calg(x);
        update(1,cnt,idx[x],G(x),1);
        f[top[x]]=query(1,n,idx[top[x]],idx[maxn[top[x]]],1)*f[son[maxn[top[x]]]];
        x=fa[top[x]];
    }
}

int main()
{
    scanf("%d%d",&n,&m);
    f[0].n=2,f[0].m=1;f[0].w[1][1]=f[0].w[2][1]=0;
    for(int i=1;i<=n;i++){
        f[i].n=2;f[i].m=1;f[i].w[1][1]=f[i].w[2][1]=0;
        scanf("%d",&a[i]);
    }
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        q[x].push_back(y);
        q[y].push_back(x);
    }
    dfs_init(1,0);
    top[1]=1;
    dfs_link(1,0);

    for(int op=1;op<=m;op++)
    {
        int x,v;
        scanf("%d%d",&x,&v);
        up(x,v);
        printf("%d\n",max(f[1].w[1][1],f[1].w[2][1]));
    }

    return 0;
}

原文地址:https://www.cnblogs.com/redegg/p/11745872.html