[模板][P4719]动态dp

Description:

给定一棵n个点的树,点带点权。

有m次操作,每次操作给定x,y,表示修改点x的权值为y。

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

Hint:

(n,m<=10^5)

Solution:

详见代码

#include<bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
typedef long long ll;
const int mxn=1e5+5;
const ll inf=1e18;
struct ed {
	int to,nxt;
}t[mxn<<1];
int n,m,cnt,tot;
int a[mxn],f[mxn],hd[mxn],sz[mxn],rk[mxn],dfn[mxn],top[mxn],bot[mxn],son[mxn];
ll dp[mxn][2];

inline void checkmax(ll &x,ll y) {if(x<y) x=y;}

struct mat {
	ll b[2][2];	
	friend mat operator * (mat x,mat y) {
		mat res={-inf,-inf,-inf,-inf}; //初始矩阵要赋为-inf
		for(int i=0;i<2;++i)
			for(int j=0;j<2;++j) 
				for(int k=0;k<2;++k)
					checkmax(res.b[i][j],x.b[i][k]+y.b[k][j]); //根据dp式子写矩阵运算
		return res;
	}
}val[mxn],w[mxn<<2];

inline void add(int u,int v) {
	t[++cnt]=(ed){v,hd[u]},hd[u]=cnt;
}

void dfs1(int u,int fa) 
{
	sz[u]=1; f[u]=fa;
	for(int i=hd[u];i;i=t[i].nxt) {
		int v=t[i].to;
		if(v==fa) continue ;
		dfs1(v,u);
		sz[u]+=sz[v];
		if(sz[v]>sz[son[u]]) son[u]=v;
	}
}

void dfs2(int u,int tp)
{
	top[u]=tp; dfn[u]=++tot; rk[tot]=u;
	if(son[u]) dfs2(son[u],tp),bot[u]=bot[son[u]];
	else bot[u]=u;
	for(int i=hd[u];i;i=t[i].nxt) {
		int v=t[i].to;
		if(v==f[u]||v==son[u]) continue ;
		dfs2(v,v);
	}
}

void init(int u)
{
	dp[u][1]=a[u];
	for(int i=hd[u];i;i=t[i].nxt) {
		int v=t[i].to;
		if(v==f[u]) continue ;
		init(v);
		dp[u][1]+=dp[v][0];
		dp[u][0]+=max(dp[v][0],dp[v][1]);
	}
}
 
void build(int l,int r,int p)
{
	if(l==r) {
		int u=rk[l]; ll g0=0,g1=a[u];
		for(int i=hd[u];i;i=t[i].nxt) {
			int v=t[i].to;
			if(v==f[u]||v==son[u]) continue ;
			g0+=max(dp[v][0],dp[v][1]),g1+=dp[v][0];
		}
		val[l]=w[p]=(mat){g0,g0,g1,-inf};
		return ;
	}
	int mid=(l+r)>>1;
	build(l,mid,ls); build(mid+1,r,rs);
	w[p]=w[ls]*w[rs];
}

void update(int l,int r,int x,int p) 
{
	if(l==r) {
		w[p]=val[l];
		return ;
	}
	int mid=(l+r)>>1;
	if(x<=mid) update(l,mid,x,ls);
	else update(mid+1,r,x,rs);
	w[p]=w[ls]*w[rs];
}

mat query(int l,int r,int ql,int qr,int p)
{
	if(ql<=l&&r<=qr) return w[p];
	int mid=(l+r)>>1;  
	if(qr<=mid) return query(l,mid,ql,qr,ls);
	if(ql>mid) return query(mid+1,r,ql,qr,rs);
	return query(l,mid,ql,qr,ls)*query(mid+1,r,ql,qr,rs);
}

mat getmat(int x) {
    return query(1,n,dfn[top[x]],dfn[bot[x]],1);
}

void modify(int x,int y)
{
	val[dfn[x]].b[1][0]+=y-a[x],a[x]=y;
	mat las,nw;
	while(x) {
		las=getmat(top[x]); update(1,n,dfn[x],1);
		nw=getmat(top[x]); x=f[top[x]];
		val[dfn[x]].b[0][0]+=max(nw.b[0][0],nw.b[1][0])-max(las.b[0][0],las.b[1][0]);
		val[dfn[x]].b[0][1]=val[dfn[x]].b[0][0];
		val[dfn[x]].b[1][0]+=nw.b[0][0]-las.b[0][0]; //+=不要写成=
	}
}

void solve(int x) {
	mat ans=getmat(x);
	printf("%lld
",max(ans.b[0][0],ans.b[1][0]));
}

int main()
{
	scanf("%d%d",&n,&m); int x,y;
	for(int i=1;i<=n;++i) scanf("%d",&a[i]);
	for(int i=1;i<n;++i) {
		scanf("%d%d",&x,&y); 
		add(x,y); add(y,x);
	}

	dfs1(1,0); dfs2(1,1); init(1); build(1,n,1);
	for(int i=1;i<=m;++i) {
		scanf("%d%d",&x,&y);
		modify(x,y); 
		solve(1);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/list1/p/10426012.html