POJ 3017 DP + 单调队列 + 堆

题意:给你一个长度为n的数列,你需要把这个数列分成几段,每段的和不超过m,问各段的最大值之和的最小值是多少?

思路:dp方程如下:设dp[i]为把前i个数分成合法的若干段最大值的最小值是多少。dp转移比较显然,dp[i] = min{dp[j] + max(a[j + 1] , a[j + 2] ... + a[i])}, 其中a[j + 1] + a[j + 2] +... + a[i] <= m;这个dp转移是O(n^2)的,我们需要用单调队列优化。单调队列维护的是a值单调递减的序列(要保证与i位置的区间和小于等于m)而单调队列的对头不一定是最优的。需要找出单调队列中的最小值,这个需要用堆或者线段树来维护一下。dp[i]的转移分为两种,一种是j + 1 到i的和正好小于m的这种转移,另一种是单调队列中的最小值,两者取min就是当前状态的最小值。

这题有两个点需要注意。1:若j在单调队列里,那么max(a[j + 1] , a[j + 2] ... + a[i])是单调队列里的下一个值。2:因为max(a[j + 1] , a[j + 2] ... + a[i])这个值是有可能随i的变化而变化,所以,如果用堆去维护单调队列中的值, 需要对每个j记录一下最新的max(a[j + 1] , a[j + 2] ... + a[i]), 不能直接扔到堆里就完事了。。。或者,使用pbds中的堆,它支持对堆中元素的修改,然而POJ不支持pbds。。。。

一般堆的代码:

#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#define LL long long
#define pii pair<int, int>
#define lowbit(x) (x << 1)
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
#define db double
#define pli pair<LL, int>
using namespace std;
const int maxn = 100010;
struct node {
	LL val;
	int pos;
	bool operator < (const node & rhs) const {
		return val > rhs.val;
	}
};
priority_queue<node> Q;
LL dp[maxn], a[maxn];
int q[maxn];
bool v[maxn];
LL val[maxn];
LL sum[maxn];
void change(LL x, int y) {
	Q.push((node){x, y});
	val[y] =  x;
}
int main() {
	int n;
	LL m;
	scanf("%d%lld", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%lld", &a[i]);
		sum[i] = sum[i - 1] + a[i];
	}
	int l = 1, r = 1, ans = 0, pos = 0;
	dp[1] = a[1];
	q[1] = 1;
	if(a[1] > m) ans = -1;
	for (int i = 2; i <= n; i++) {
		while(sum[i] - sum[pos] > m) pos++;
		if(pos == i) {
			ans = -1;
			break;
		}
		while(l <= r && sum[i] - sum[q[l] - 1] > m) {
			v[q[l]] = 1;
			l++;
		}
		while(l <= r && a[q[r]] <= a[i]) {
			v[q[r]] = -1;
			r--;
		}
		if(l <= r)
			change(dp[q[r]] + a[i], q[r]);
		q[++r] = i;
		dp[i] = dp[pos] + a[q[l]];
		while(Q.size() && (v[Q.top().pos] == 1 || val[Q.top().pos] != Q.top().val)) {
			Q.pop();
		}
		if(Q.size()) {
			dp[i] = min(dp[i], Q.top().val);
		}
	}
	if(ans == -1) {
		printf("%d
", ans);
	} else {
		printf("%lld
", dp[n]);
	}
}

pb_ds的代码(应该是对的吧)

#include <bits/stdc++.h>
#define LL long long
#define pii pair<int, int>
#define lowbit(x) (x << 1)
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
#define db double
#define pli pair<LL, int>
#include <ext/pb_ds/priority_queue.hpp>
using namespace std;
using namespace __gnu_pbds;
const int maxn = 100010;
struct node {
	LL val;
	int pos;
	bool operator < (const node & rhs) const {
		return val > rhs.val;
	}
};
typedef __gnu_pbds::priority_queue<node> Heap;
Heap Q;
Heap::point_iterator id[maxn];
LL dp[maxn], a[maxn];
int q[maxn];
bool v[maxn];
LL sum[maxn];
void change(LL x, int y) {
	if(id[y] != 0)Q.modify(id[y], (node){x, y});
	else id[y] = Q.push((node){x, y});
}
int main() {
	int n, m;
	//freopen("17.in", "r", stdin);
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
		sum[i] = sum[i - 1] + a[i];
	}
	int l = 1, r = 1, ans = 0, pos = 0;
	dp[1] = a[1];
	q[1] = 1;
	if(a[1] > m) ans = -1;
	for (int i = 2; i <= n; i++) {
		while(sum[i] - sum[pos] > m) pos++;
		if(pos == i) {
			ans = -1;
			break;
		}
		while(l <= r && sum[i] - sum[q[l] - 1] > m) {
			v[q[l]] = 1;
			l++;
		}
		if(l > r) {
			ans = -1;
			break;
		}
		while(l <= r && a[q[r]] <= a[i]) {
			v[q[r]] = 1;
			r--;
		}
//		Q.push(make_pair(dp[pos] + a[i], a[q[l]]));
//		printf("%d
", Q.size());
		if(l <= r)
			change(dp[q[r]] + a[i], q[r]);
		q[++r] = i;
		dp[i] = dp[pos] + a[q[l]];
		while(Q.size() && v[Q.top().pos]) {
			id[Q.top().pos] = 0;
			Q.pop();
		}
		if(Q.size()) {
			//printf("%lld %d
", Q.top().val, Q.top().pos);
			dp[i] = min(dp[i], Q.top().val);
		}
	}
	for (int i = 1; i <= n; i++)
        printf("%d %lld
", i, dp[i]);
	if(ans == -1) {
		printf("%d
", ans);
	} else {
		printf("%lld
", dp[n]);
	}
}

  

原文地址:https://www.cnblogs.com/pkgunboat/p/10752884.html