数据结构优化DP

数据结构优化DP

用途

在DP的转移中需要用到某一个阶段的最值的时候可以用线段树和树状数组等数据结构进行维护,在O(1)或O(log N) 的时间复杂度内完成转移

例题

cleaning shifts

分析

首先设计出状态,dp[x]表示从m清理到x所付出的最小代价
很显然,状态转移方程为

很显然,我们的每一次的转移都会用到一个区间的最小值,所以考虑运用线段树进行优化

build

我们在[m,e]上建立一颗线段树,存储DP的最小值

change

当我们更新完一个DP的值的时候,就在线段树中插入这个值

ask

每一次状态转移我们都需要在区间查找最小值

代码
#include<bits/stdc++.h>
using namespace std;

const int MAXN=1e5+5,MAXM=9e5+5;
struct Node
{
	int t1,t2,s;
}cow[MAXN];
struct node
{
	int l,r,val;
}lst[MAXM];
int n,s,e,dp[MAXM],ans;

void build_tree(int id,int l,int r)
{
	lst[id].l=l; lst[id].r=r;
	if( l==r )
	{
		lst[id].val=dp[l];
		return;
	}
	
	int mid=(l+r)/2;
	build_tree(id*2,l,mid);
	build_tree(id*2+1,mid+1,r);
	lst[id].val=min(lst[id*2].val,lst[id*2+1].val);
	return;
}

void change_tree(int id,int ver,int val)
{
	if( lst[id].l==lst[id].r )
	{
		lst[id].val=dp[ver];
		return;
	}
	
	int mid=(lst[id].l+lst[id].r)/2;
	if( mid >= ver ) change_tree(id*2,ver,val);
	else change_tree(id*2+1,ver,val);
	lst[id].val=min(lst[id*2].val,lst[id*2+1].val);
	return;
}

int ask_tree(int id,int l,int r)
{
	if( lst[id].l==lst[id].r ) return lst[id].val;
	
	int mid=(lst[id].l+lst[id].r)/2,tem=0x7f7f7f7f;
	if( mid >= l ) tem=min(tem,ask_tree(id*2,l,r));
	if( mid <= r ) tem=min(tem,ask_tree(id*2+1,l,r));
	return tem;
}

bool cmp(Node x,Node y)
{
	return x.t2 < y.t2;
}

int main()
{
	scanf("%d%d%d",&n,&s,&e);
	for(int i=1;i<=n;i++) scanf("%d%d%d",&cow[i].t1,&cow[i].t2,&cow[i].s);
	sort(cow+1,cow+n+1,cmp);
	memset(dp,0x7f7f7f7f,sizeof(dp));
	dp[s]=0;
	build_tree(1,s,e);
	for(int i=1;i<=n;i++)
	{
		int tem=ask_tree(1,cow[i].t1-1,cow[i].t2);
		dp[cow[i].t2]=tem+cow[i].s;
		if( cow[i].t2 >= e ) {ans=dp[cow[i].t2]; break;}
		change_tree(1,cow[i].t2,dp[i]);
	}
	if( ans==2139075787 ) printf("-1");
	else printf("%d",ans);
	return 0;
}

the battle of chibi

分析

实际上是给定一个长度为N的数列,求数列中有多少个长度为M的严格递增子序列
首先设计状态 dp[i] [j] 表示前j个数中以第j个数为结尾的长度为i 的严格递增序列有多少个

状态转移方程为:

很显然,在状态转移的时候要多次用到前缀和,所以想到树状数组,因为数据范围太大,所以先将a数组离散化到disc数组,然后用c[x]表示disc[x]的前缀和

add

将c[disc[j]]增加dp[i-1] [j]

ask

查询disc[j]的前缀和

代码
#include<bits/stdc++.h>
using namespace std;

const int MAXN=1e3+5,mod=1e9+7;
struct Node
{
	int id,val;
}a[MAXN],b[MAXN];
int c[MAXN],disc[MAXN],n,m,t,dp[MAXN][MAXN];

int lowbit(int x)
{
	return x & -x;
}

bool cmp(Node x,Node y)
{
	return x.val == y.val ? x.id > y.id : x.val < y.val;
}

int ask(int x)
{
	int tem=0;
	while( x )
	{
		tem+=c[x]; tem%=mod;
		x-=lowbit(x);
	}
	return tem;
}

void add(int x,int y)
{
	while( x <= n+1 )
	{
		c[x]+=y; c[x]%=mod;
		x+=lowbit(x);
	}
	return;
}

void work(int k)
{
	memset(a,0,sizeof(a)); memset(dp,0,sizeof(dp));
	memset(disc,0,sizeof(disc));
	dp[0][0]=1;
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&a[i].val),a[i].id=i,b[i]=a[i];
	sort(b+1,b+n+1,cmp);
	for(int i=1;i<=n;i++) disc[b[i].id]=i+1;
	for(int i=1;i<=m;i++)
	{
		memset(c,0,sizeof(c));
		add(1,dp[i-1][0]);
		for(int j=1;j<=n;j++)
		{
			dp[i][j]=ask(disc[j]-1);
			add(disc[j],dp[i-1][j]);
		}
	}
	int ans=0;
	for(int i=1;i<=n;i++) ans+=dp[m][i],ans%=mod;
	printf("Case #%d: %d
",k,ans%mod);
	return;
}

int main()
{
	scanf("%d",&t);
	for(int i=1;i<=t;i++) work(i);
	return 0;
}

作业题

fence obstacle course
estimation

原文地址:https://www.cnblogs.com/BZDYL/p/12093386.html