动态DP(动态树分治)

Link
首先我们有一个静态的dp。
(f_{u,0/1})表示只考虑(u)的子树,(u)不选/选的答案。
那么很显然有:

[egin{aligned} f_{u,0}&=sumlimits_{vin son_u}max(f_{v,0},f_{v,1})\ f_{u,1}&=w_u+sumlimits_{vin son_u}f_{v,0} end{aligned} ]

考虑利用重链剖分来进行这个过程,设(h_u)表示(u)的重儿子,(g_{u,0/1})表示(f_{u,0/1})在不考虑(h_u)子树情况下的答案。
那么有:

[egin{aligned} g_{u,0}&=sumlimits_{vin son_uwedge v e h_u}max(f_{v,0},f_{v,1})\ g_{u,1}&=w_u+sumlimits_{vin son_uwedge v e h_u}f_{v,0}\ f_{u,0}&=max(f_{h_u,0},f_{h_u,1})+g_{u,0}\ f_{u,1}&=f_{h_u,0}+g_{u,1} end{aligned} ]

对于一条重链,实际上我们只关心(f_{top})
假如我们已经求出了链上所有点的(g),那么我们可以做一个序列dp得到(f_{top})
实际上我们可以把这个序列dp的转移写成矩阵乘法的形式。
定义(C=AB)为满足(C_{i,j}=maxlimits_k(A_{i,k}+B_{k,j}))的矩阵,那么有:

[egin{pmatrix}f_{h_u,0}&f_{h_u,1}end{pmatrix}egin{pmatrix}g_{u,0}&g_{u,1}\g_{u,0}&-inftyend{pmatrix}=egin{pmatrix}f_{u,0}&f_{u,1}end{pmatrix} ]

注意到新定义的矩阵乘法仍然具有结合律,因此我们可以用线段树维护每条重链上的矩阵的区间积。为了方便我们在线段树外同时记录每个点的转移矩阵。
这样做的时间复杂度为(O(nlog n+qlog^2n))

#include<cctype>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
const int N=100007,inf=1e9;
char ibuf[1<<23|1],*iS=ibuf;
int n,m,val[N],fa[N],size[N],son[N],top[N],ch[N],dfn[N],id[N],f[N][2];
std::vector<int>e[N];
struct matrix{int a[2][2];int*operator[](int x){return a[x];}}t[4*N],a[N];
matrix operator*(matrix a,matrix b)
{
    matrix c;
    c[0][0]=std::max(a[0][0]+b[0][0],a[0][1]+b[1][0]),c[0][1]=std::max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
    c[1][0]=std::max(a[1][0]+b[0][0],a[1][1]+b[1][0]),c[1][1]=std::max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
    return c;
}
int read(){int x=0,f=1;while(isspace(*iS))++iS;if(*iS=='-')++iS,f=-1;while(isdigit(*iS))(x*=10)+=*iS++&15;return f*x;}
void dfs1(int u)
{
    size[u]=1;
    for(int v:e[u]) if(v^fa[u]) if(fa[v]=u,dfs1(v),size[u]+=size[v],size[v]>size[son[u]]) son[u]=v;
}
void dfs2(int u,int tp)
{
    static int tim;id[dfn[u]=++tim]=ch[u]=u,top[u]=tp;
    if(son[u]) dfs2(son[u],tp),ch[u]=ch[son[u]];
    for(int v:e[u]) if(v^fa[u]&&v^son[u]) dfs2(v,v);
}
void dfs3(int u)
{
    f[u][1]=val[u];
    for(int v:e[u]) if(v^fa[u]) dfs3(v),f[u][0]+=std::max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
}
matrix get(int u)
{
    int g0=0,g1=val[u];
    for(int v:e[u]) if(v^fa[u]&&v^son[u]) g0+=std::max(f[v][0],f[v][1]),g1+=f[v][0];
    return {g0,g0,g1,-inf};
}
#define ls p<<1
#define rs p<<1|1
#define mid ((l+r)/2)
void pushup(int p){t[p]=t[ls]*t[rs];}
void build(int p,int l,int r)
{
    if(l==r) return a[l]=t[p]=get(id[l]),void();
    build(ls,l,mid),build(rs,mid+1,r),pushup(p);
}
void update(int p,int l,int r,int x)
{
    if(l==r) return t[p]=a[l],void();
    x<=mid? update(ls,l,mid,x):update(rs,mid+1,r,x),pushup(p);
}
matrix query(int p,int l,int r,int L,int R)
{
    if(L<=l&&r<=R) return t[p];
    if(R<=mid) return query(ls,l,mid,L,R);
    if(L>mid) return query(rs,mid+1,r,L,R);
    return query(ls,l,mid,L,R)*query(rs,mid+1,r,L,R);
}
#undef ls
#undef rs
#undef mid
void modify(int u,int w)
{
    a[dfn[u]][1][0]+=w-val[u],val[u]=w;
    while(u)
    {
	matrix p=query(1,1,n,dfn[top[u]],dfn[ch[u]]);
	update(1,1,n,dfn[u]);
	matrix q=query(1,1,n,dfn[top[u]],dfn[ch[u]]);
	if(!(u=fa[top[u]]))break;
	int x=dfn[u],g0=p[0][0],g1=p[1][0],f0=q[0][0],f1=q[1][0];
	a[x][0][0]=a[x][0][1]=a[x][0][0]+std::max(f0,f1)-std::max(g0,g1),a[x][1][0]=a[x][1][0]+f0-g0;
    }
}
void work()
{
    int u=read(),w=read();modify(u,w);
    matrix ans=query(1,1,n,dfn[1],dfn[ch[1]]);
    printf("%d
",std::max(ans[0][0],ans[1][0]));
}
int main()
{
    fread(ibuf,1,1<<23,stdin);
    n=read(),m=read();
    for(int i=1;i<=n;++i) val[i]=read();
    for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
    dfs1(1),dfs2(1,1),dfs3(1);
    build(1,1,n);
    for(int i=1;i<=m;++i) work();
}

还有一个叫做全局平衡二叉树的东西。
类似于LCT,大致思想还是用Splay维护每一条重链。
注意到树是静态的,因此并不需要支持rotate等改变数的形态的操作,因此常数会小很多。
对于每条重链,为了建出较为平衡的bst,我们按轻儿子(size)之和的加权重心递归建树。

#include<cctype>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
const int N=1000007,inf=1e9;
char ibuf[1<<27|1],*iS=ibuf;
int n,q,val[N],son[N],sz[N];std::vector<int>e[N];
int read(){int x=0,f=1;while(isspace(*iS))++iS;if(*iS=='-')++iS,f=-1;while(isdigit(*iS))(x*=10)+=*iS++&15;return f*x;}
struct matrix
{
    int a[2][2];
    matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
    int*operator[](int x){return a[x];}
    int cal(){return std::max(std::max(a[0][0],a[0][1]),std::max(a[1][0],a[1][1]));}
};
matrix operator*(matrix a,matrix b)
{
    matrix c;
    c[0][0]=std::max(a[0][0]+b[0][0],a[0][1]+b[1][0]),c[0][1]=std::max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
    c[1][0]=std::max(a[1][0]+b[0][0],a[1][1]+b[1][0]),c[1][1]=std::max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
    return c;
}
void dfs1(int u,int fa)
{
    sz[u]=1;
    for(int v:e[u]) if(v^fa) if(dfs1(v,u),sz[u]+=sz[v],sz[v]>sz[son[u]]) son[u]=v;
}
struct BST
{
    int root,top,ch[N][2],fa[N],stk[N],vis[N],size[N];matrix f[N],sum[N];
    void init(){f[0][0][0]=f[0][1][1]=sum[0][0][0]=sum[0][1][1]=0;for(int i=1;i<=n;++i)f[i][0][1]=val[i],f[i][0][0]=f[i][1][0]=0;}
    int nroot(int p){return ch[fa[p]][0]==p||ch[fa[p]][1]==p;}
    void pushup(int p){sum[p]=sum[ch[p][0]]*f[p]*sum[ch[p][1]];}
    void merge(int u,int v){f[u][1][0]+=sum[v].cal(),f[u][0][0]=f[u][1][0],f[u][0][1]+=std::max(sum[v][0][0],sum[v][1][0]),fa[v]=u;}
    int build(int l,int r)
    {
	if(l>r) return 0;
	int tot=0;for(int i=l;i<=r;++i)tot+=size[stk[i]];
        for(int i=l,now=size[stk[i]],ls,rs;i<=r;++i,now+=size[stk[i]])
            if(2*now>=tot)
		return ls=build(l,i-1),rs=build(i+1,r),ch[stk[i]][0]=rs,ch[stk[i]][1]=ls,fa[ls]=fa[rs]=stk[i],pushup(stk[i]),stk[i];
    }
    int build(int p)
    {
	for(int u=p;u;u=son[u]) vis[u]=1;
	for(int u=p;u;u=son[u]) for(int v:e[u]) if(!vis[v]) merge(u,build(v));
	top=0;for(int u=p;u;u=son[u])stk[++top]=u,size[u]=sz[u]-sz[son[u]];
	return build(1,top);
    }
    void update(int u,int w)
    {
	f[u][0][1]+=w-val[u],val[u]=w;
	for(int v=u;v;v=fa[v])
	    if(!nroot(v)&&fa[v]) 
	    {
		f[fa[v]][0][0]-=sum[v].cal(),f[fa[v]][1][0]=f[fa[v]][0][0],f[fa[v]][0][1]-=std::max(sum[v][0][0],sum[v][1][0]);
		pushup(v);
		f[fa[v]][0][0]+=sum[v].cal(),f[fa[v]][1][0]=f[fa[v]][0][0],f[fa[v]][0][1]+=std::max(sum[v][0][0],sum[v][1][0]);
	    }
	    else pushup(v);
    }
}bst;
int main()
{
    fread(ibuf,1,1<<27,stdin);
    n=read(),q=read();
    for(int i=1;i<=n;++i) val[i]=read();
    for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
    dfs1(1,0),bst.init(),bst.root=bst.build(1);
    for(int i=1,u,w;i<=q;++i) u=read(),w=read(),bst.update(u,w),printf("%d
",bst.sum[bst.root].cal());
}
原文地址:https://www.cnblogs.com/cjoierShiina-Mashiro/p/12845678.html