CF833B-线段树优化DP

CF833B-线段树优化DP

题意

将一个长为(n)的序列分成(k)段,每段贡献为其中不同数字的个数,求最大贡献和。

思路

此处感谢@gxy001 聚铑的精彩讲解

先考虑暴力DP,可以想到一个时空复杂度(O(n^2k))的方法,即记录前i个数字分成了j段。我们现在来思考几个问题来优化这个操作:

  1. 对于一个数字,它对那些地方实际有贡献?
  2. 每次分割出一个区间段对后续操作有影响的位置在哪?
  3. 每次转移都从哪些地方继承?

下来一一解答这些问题。

  1. 对于一个数字,它能产生贡献的区间其实就是该数字上一次出现的位置的后一位到它本身的位置。
  2. 对于每次划分,它以前的位置的贡献已经被考虑,所以我们只能考虑后面的位置。
  3. 相应的,每次转移会继承前面所有DP值的最大值

那么我们可以将k提出来,每次循环继承上一次所有的dp值。因为只考虑从前面转移dp值,所以不会对之前的决策产生影响,所以是正确的。

看看1、3问题的答案,是不是想到了RMQ和区间赋值?

于是我们可以通过线段树来实现DP优化。

具体来讲,迭代k次,每次线段树更新为上一次序列的dp值,然后从前往后扫,每个数会对其上述区间产生1的贡献,转移继承前面所有dp值的最大值即可。

时间复杂度将一维优化为log。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
#include<cmath>
using namespace std;
inline int read(){
	int w=0,x=0;char c=getchar();
	while(!isdigit(c))w|=c=='-',c=getchar();
	while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
	return w?-x:x;
}
namespace star
{
	const int maxn=35005;
	int n,k,cur[maxn],pre[maxn],f[maxn][60];
	struct SegmentTree{
		#define ls (ro<<1)
		#define rs (ro<<1|1)
		struct tree{
			int l,r,tag,v;
		}e[maxn<<2];
		inline void pushup(int ro){
			e[ro].v=max(e[ls].v,e[rs].v);
		}
		inline void pushdown(int ro){
			e[ls].tag+=e[ro].tag;e[rs].tag+=e[ro].tag;
			e[ls].v+=e[ro].tag;e[rs].v+=e[ro].tag;
			e[ro].tag=0;
		}
		void build(int ro,int l,int r){
			e[ro].l=l,e[ro].r=r;
			if(l==r)return;
			int mid=l+r>>1;
			build(ls,l,mid);
			build(rs,mid+1,r);
		}
		void rebuild(int tim,int ro){
			int l=e[ro].l,r=e[ro].r;
			e[ro].tag=0;
			if(l==r){
				e[ro].v=f[l][tim];return;
			}
			rebuild(tim,ls);rebuild(tim,rs);
			pushup(ro);
		}
		void update(int ro,int x,int y){
			int l=e[ro].l,r=e[ro].r;
			if(l>=x and r<=y){
				e[ro].v+=1;
				e[ro].tag+=1;return;
			}
			pushdown(ro);
			int mid=l+r>>1;
			if(mid>=x)update(ls,x,y);
			if(mid<y)update(rs,x,y);
			pushup(ro);
		}
		int query(int ro,int x,int y){
			int l=e[ro].l,r=e[ro].r;
			if(l==x and r==y)return e[ro].v;
			pushdown(ro);
			int mid=l+r>>1;
			if(mid<x)return query(rs,x,y);
			else if(mid>=y)return query(ls,x,y);
			else return max(query(ls,x,mid),query(rs,mid+1,y));
		}
		#undef ls
		#undef rs 
	}T;
	inline void work(){
		n=read(),k=read();
		for(int x,i=1;i<=n;i++)x=read(),pre[i]=cur[x],cur[x]=i;
		T.build(1,0,n);
		for(int i=1;i<=k;i++){
			T.rebuild(i-1,1);
			for(int x=1;x<=n;x++) T.update(1,pre[x],x-1),f[x][i]=T.query(1,0,x-1);
		}
		printf("%d",f[n][k]);
	}
}
signed main(){
	star::work();
	return 0;
}
原文地址:https://www.cnblogs.com/BrotherHood/p/13791099.html