// 【比赛】NOIP2018 保卫王国 — Contest problem: NOIP 2018, "保卫王国" ("Defend the Kingdom").

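// DDP模板题 — a template problem for dynamic DP (DDP): a tree DP kept up to date under point modifications using heavy-light decomposition and a segment tree of (max, +) 2x2 matrices.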

#include<bits/stdc++.h>
// Contest shorthand macros.
#define ui unsigned int
#define ll long long
#define db double
#define ld long double
#define ull unsigned long long
#define ft first
#define sd second
#define pb(a) push_back(a)
#define mp(a,b) std::make_pair(a,b)
// Iterates over a copy of each element of b.
#define ITR(a,b) for(auto a:b)
// Inclusive ascending / descending loops; the bound c is evaluated once into a##end.
// No `register` here: that storage class specifier was removed in C++17.
#define REP(a,b,c) for(int a=(b),a##end=(c);a<=a##end;++a)
#define DEP(a,b,c) for(int a=(b),a##end=(c);a>=a##end;--a)
const int MAXN=100000+10;
// inf: the "minus infinity" filler of the (max, +) matrices.
// vinf: bonus/penalty added to a node's weight to force it into / out of the
// independent set; it needs to exceed any achievable total weight.
const ll inf=1e18;
const ll vinf=1e12;
// Node count, query count, edge count.
int n,m,e;
// Linked-list adjacency: beg = first edge of a node, nex = next edge, to = edge endpoint.
int beg[MAXN],nex[MAXN<<1],to[MAXN<<1];
// Heavy-light decomposition: subtree size, heavy child, DFS position,
// last position of the node's heavy chain, chain top, parent.
int size[MAXN],hson[MAXN],st[MAXN],ed[MAXN],top[MAXN],fa[MAXN];
// Node weights and the running DFS position counter.
int w[MAXN],cnt;
// f[x][0] / f[x][1]: best independent-set weight in x's subtree with x excluded / included.
// all: sum of all node weights.
ll f[MAXN][2],all;
// Problem type tag from the input (read, not otherwise used).
char type[5];
// Lower x to y when y is strictly smaller; reports whether x changed.
template<typename T> inline bool chkmin(T &x,T y)
{
	if(!(y<x))return false;
	x=y;
	return true;
}
// Raise x to y when y is strictly larger; reports whether x changed.
template<typename T> inline bool chkmax(T &x,T y)
{
	if(!(y>x))return false;
	x=y;
	return true;
}
struct Matrix{
	ll a[2][2];
	Matrix(){
		REP(i,0,1)REP(j,0,1)a[i][j]=-inf;
	};
	inline Matrix operator * (const Matrix &A) const {
		Matrix B;
		REP(i,0,1)REP(k,0,1)REP(j,0,1)chkmax(B.a[i][j],a[i][k]+A.a[k][j]);
		return B;
	};
};
// Per-position transition matrices, indexed by DFS position (filled by dfs2).
Matrix val[MAXN];
// Segment tree over positions 1..n. Node rt covering [l, r] stores the
// left-to-right max-plus product val[l] * val[l+1] * ... * val[r].
struct Segment_Tree{
	Matrix sum[MAXN<<2];
	// Recombine node rt from its children; the product is not commutative,
	// so the left child must be the left operand.
	inline void PushUp(int rt)
	{
		const int lc=rt*2;
		sum[rt]=sum[lc]*sum[lc+1];
	}
	// Build node rt over [l, r] directly from val[].
	inline void Build(int rt,int l,int r)
	{
		if(l!=r)
		{
			const int mid=(l+r)>>1;
			Build(rt*2,l,mid);
			Build(rt*2+1,mid+1,r);
			PushUp(rt);
			return;
		}
		sum[rt]=val[l];
	}
	// Overwrite the matrix stored at position ps with k and refresh its ancestors.
	inline void Update(int rt,int l,int r,int ps,Matrix k)
	{
		if(l==r)
		{
			sum[rt]=k;
			return;
		}
		const int mid=(l+r)>>1;
		if(ps>mid)Update(rt*2+1,mid+1,r,ps,k);
		else Update(rt*2,l,mid,ps,k);
		PushUp(rt);
	}
	// Ordered product over positions [L, R] (restricted to node rt's range [l, r]).
	inline Matrix Query(int rt,int l,int r,int L,int R)
	{
		if(L<=l&&r<=R)return sum[rt];
		const int mid=(l+r)>>1;
		if(R<=mid)return Query(rt*2,l,mid,L,R);
		if(L>mid)return Query(rt*2+1,mid+1,r,L,R);
		return Query(rt*2,l,mid,L,R)*Query(rt*2+1,mid+1,r,L,R);
	}
};
Segment_Tree T;
// Read one (optionally negative) integer from stdin into x.
// Skips every character before the first '-' or digit, then consumes the digits
// (plus the single character that ends them). If end of input is reached before
// any digit, x is set to 0 instead of looping forever.
template<typename T> inline void read(T &x)
{
	T data=0,sign=1;
	// getchar() returns int; keeping it as int keeps EOF distinguishable from data.
	int ch=getchar();
	while(ch!=EOF&&ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
	if(ch=='-')sign=-1,ch=getchar();
	while(ch>='0'&&ch<='9')data=data*10+(ch-'0'),ch=getchar();
	x=data*sign;
}
// Print integer x in decimal to stdout, followed by ch unless ch is '\0'.
// The original used the empty character literal '', which does not compile.
template<typename T> inline void write(T x,char ch='\0')
{
	if(x<0)
	{
		putchar('-');
		// Split off the last digit before negating, so the most negative value of
		// a signed type is printed without overflowing (x % 10 lies in (-10, 0]).
		T q=x/10,r=x%10;
		if(q)write(-q);
		putchar('0'-r);
	}
	else
	{
		if(x>9)write(x/10);
		putchar(x%10+'0');
	}
	if(ch!='\0')putchar(ch);
}
// Smaller of x and y (y when they compare equal).
template<typename T> inline T min(T x,T y)
{
	if(x<y)return x;
	return y;
}
// Larger of x and y (y when they compare equal).
template<typename T> inline T max(T x,T y)
{
	if(x>y)return x;
	return y;
}
// Add the directed edge x -> y at the head of x's adjacency list.
inline void insert(int x,int y)
{
	const int id=++e;
	nex[id]=beg[x];
	to[id]=y;
	beg[x]=id;
}
// First heavy-light pass: record parent fa[], subtree size[] and heavy child
// hson[] (a child with the largest subtree; the first one found wins ties).
// `register` was dropped from the loop: it is ill-formed in C++17.
// NOTE(review): recursion depth equals tree height (up to n on a path graph).
inline void dfs1(int x,int p)
{
	int res=0; // largest child subtree size seen so far
	size[x]=1;fa[x]=p;
	for(int i=beg[x];i;i=nex[i])
	{
		const int y=to[i];
		if(y==p)continue;
		dfs1(y,x);
		size[x]+=size[y];
		if(chkmax(res,size[y]))hson[x]=y;
	}
}
inline void dfs2(int x,int tp)
{
	top[x]=tp;st[x]=++cnt;
	val[cnt].a[0][0]=val[cnt].a[0][1]=f[x][0];
	val[cnt].a[1][0]=f[x][1];
	if(hson[x])
	{
		val[cnt].a[0][0]-=max(f[hson[x]][0],f[hson[x]][1]);
		val[cnt].a[0][1]=val[cnt].a[0][0];
		val[cnt].a[1][0]-=f[hson[x]][0];
		dfs2(hson[x],tp);ed[x]=ed[hson[x]];
	}
	else ed[x]=cnt;
	for(register int i=beg[x];i;i=nex[i])
		if(to[i]==fa[x]||to[i]==hson[x])continue;
		else dfs2(to[i],to[i]);
}
// Static tree DP for maximum-weight independent set:
//   f[x][1] = w[x] + sum of f[child][0]          (x taken, children not taken)
//   f[x][0] = sum of max(f[child][0], f[child][1]) (x not taken)
// Relies on fa[] from dfs1 and on f[][] starting at zero.
// `register` was dropped from the loop: it is ill-formed in C++17.
inline void dfs(int x)
{
	f[x][1]=w[x];
	for(int i=beg[x];i;i=nex[i])
	{
		const int y=to[i];
		if(y==fa[x])continue;
		dfs(y);
		f[x][1]+=f[y][0];
		f[x][0]+=max(f[y][0],f[y][1]);
	}
}
// Build every static structure rooted at node 1.
inline void init()
{
	dfs1(1,0);       // parents, subtree sizes, heavy children
	dfs(1);          // base DP values (uses fa[])
	dfs2(1,1);       // chain layout and per-position matrices (uses hson[], f[][])
	T.Build(1,1,n);  // segment tree over the matrices
}
inline void solve(int u,ll v)
{
	Matrix A,B,C;
	B=T.Query(1,1,n,st[u],st[u]);
	A=T.Query(1,1,n,st[top[u]],ed[u]);
	B.a[1][0]+=v;
	T.Update(1,1,n,st[u],B);
	while(u)
	{
		B=T.Query(1,1,n,st[top[u]],ed[u]);
		u=fa[top[u]];
		if(!u)break;
		C=T.Query(1,1,n,st[u],st[u]);
		C.a[0][0]+=max(B.a[0][0],B.a[1][0])-max(A.a[0][0],A.a[1][0]);
		C.a[0][1]=C.a[0][0];
		C.a[1][0]+=B.a[0][0]-A.a[0][0];
		A=T.Query(1,1,n,st[top[u]],ed[u]);
		T.Update(1,1,n,st[u],C);
	}
}
// Maximum independent-set weight of the whole tree, with one vinf bonus removed
// for each of the two constraints that is 0 (those nodes were forced in by +vinf).
inline ll value(int ot1,int ot2)
{
	const Matrix root=T.Query(1,1,n,st[1],ed[1]);
	ll best=max(root.a[0][0],root.a[1][0]);
	if(!ot1)best-=vinf;
	if(!ot2)best-=vinf;
	return best;
}
int main()
{
	freopen("defense.in","r",stdin);
	freopen("defense.out","w",stdout);
	read(n);read(m);scanf("%s",type);
	REP(i,1,n)read(w[i]),all+=w[i];
	REP(i,1,n-1)
	{
		int u,v;read(u);read(v);
		insert(u,v);insert(v,u);
	}
	init();
	while(m--)
	{
		int a,x,b,y;read(a);read(x);read(b);read(y);
		if((fa[a]==b||fa[b]==a)&&!x&&!y)
		{
			puts("-1");
			continue;
		}
		solve(a,x?-vinf:vinf);
		solve(b,y?-vinf:vinf);
		printf("%lld
",all-value(x,y));
		solve(a,x?vinf:-vinf);
		solve(b,y?vinf:-vinf);
	}
	return 0;
}
// 原文地址 (original post): https://www.cnblogs.com/hongyj/p/10206244.html