线段树教做人系列(3) HDU 4913

题意及思路看这篇博客就行了,讲得很详细。

下面是我自己的理解:

如果只有2,没有3的话,做法就很简单了,只需要对数组排个序,然后从小到大枚举最大的那个数。那么它对答案的贡献为(假设这个数排序后的位置是pos)2 ^ (pos - 1) * 2 ^ a[pos]。意思是a[pos]这个数必选,其它比它小的数可选可不选,有2^(pos - 1)种情况。现在相当于变成了一个二维的问题。对于这种问题,我们常见的做法是确定一维,在从前往后扫描某一维时加上另一维对答案的贡献。对于这个题,我们可以按数组b从小到大排序,去计算a的贡献。假设现在扫描到的第pos个位置(二元组(a[i], b[i])已经按数组b排序),我们考虑来计算a[i]对答案的贡献。a对答案的贡献分为2部分,一部分是之前已经出现过的,小于等于a[i]的值,假设一共有x个,那么这部分的贡献为(2 ^ x * 2 ^ a[i]),那么大于a[i]的部分呢?其实和这个式子差不多。对于每个已经出现过,并且大于a[i]的a[j],假设已经出现过的比a[j]小的数有y个,那么贡献为2 ^ (y - 1) * 2 * a[j]。为什么是y - 1? 因为a[i]是必选的。通过观察,我们可以发现,每一个a[j]对答案的贡献,取决当前已经出现过的数中有多少个比它小的数,所以我们可以这样维护:在每次插入一个值时,先询问在这个数之前出现了多少个数(假设有x个),然后插入2 ^ x * 2 ^ a[i],询问[i,n]的区间和,就是这一阶段的答案。之后,要把[i + 1,n]中的数乘2,因为他们的前面都多了一个a[i]。

代码:

#include<bits/stdc++.h>
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
#define LL long long
using namespace std;
const int maxn = 100010;
const LL mod = 1000000007;
struct node{
	int x, y, rank;
};
bool cmp1(node x, node y) {
	return x.x == y.x ? x.y < y.y : x.x < y.x;
}
bool cmp2(node x, node y) {
	return x.y == y.y ? x.x < y.x : x.y < y.y;
}
node a[maxn];
struct SegementTree {
	LL sum, cnt, lz;
};
SegementTree tr[maxn * 4];
LL qpow(LL x, LL y) {
	LL ans = 1;
	for (; y; y >>= 1) {
		if(y & 1) ans = (ans * x) % mod;
		x = (x * x) % mod;
	}
	return ans;
}
void pushup(int x) {
	tr[x].sum = (tr[ls(x)].sum +tr[rs(x)].sum) % mod;
	tr[x].cnt = (tr[ls(x)].cnt + tr[rs(x)].cnt) % mod;
}
void maintain(int x, int y) {
	tr[x].sum = (tr[x].sum * qpow(2, y)) % mod;
	tr[x].lz += y;
}
void pushdown(int x) {
	if(tr[x].lz) {
		if(tr[ls(x)].cnt) maintain(ls(x), tr[x].lz);
		if(tr[rs(x)].cnt) maintain(rs(x), tr[x].lz);
		tr[x].lz = 0;
	}
}
void build(int x, int l, int r) {
	if(l == r) {
		tr[x].sum = tr[x].cnt = 0;
		return;
	}
	int mid = (l + r) >> 1;
	build(ls(x), l, mid);
	build(rs(x), mid + 1, r);
	pushup(x);
}
void update_cnt(int x, int l, int r, int pos, int y, int z) {
	if(l == r) {
		tr[x].cnt = 1;
		tr[x].sum = (qpow(2, y) * qpow(2, z)) % mod;
		return;
	}
	pushdown(x);
	int mid = (l + r) >> 1;
	if(pos <= mid) update_cnt(ls(x), l, mid, pos, y, z);
	else update_cnt(rs(x), mid + 1, r, pos ,y, z);
	pushup(x);
}
void update_sum(int x, int l, int r, int ql, int qr) {
	if(l >= ql && r <= qr) {
		tr[x].lz++;
		tr[x].sum = (tr[x].sum * 2) % mod;
		return;
	}
	pushdown(x);
	int mid = (l + r) >> 1;
	if(ql <= mid) update_sum(ls(x), l, mid, ql, qr);
	if(qr > mid) update_sum(rs(x), mid + 1, r, ql, qr);
	pushup(x);
}
LL query_cnt(int x, int l, int r, int ql, int qr) {
	if(l >= ql && r <= qr) {
		return tr[x].cnt;
	}
	int mid = (l + r) >> 1;
	pushdown(x);
	LL ans = 0;
	if(ql <= mid) ans += query_cnt(ls(x), l, mid, ql, qr);
	if(qr > mid) ans += query_cnt(rs(x), mid + 1, r, ql, qr);
	return ans;
}
LL query_sum(int x, int l, int r, int ql, int qr) {
	if(l >= ql && r <= qr) {
		return tr[x].sum;
	}
	int mid = (l + r) >> 1;
	LL ans = 0;
	pushdown(x);
	if(ql <= mid) ans += query_sum(ls(x), l, mid, ql, qr);
	if(qr > mid) ans += query_sum(rs(x), mid + 1, r, ql, qr);
	return ans % mod;
}
int main() {
	int n;
	while(~scanf("%d", &n)) {
		for (int i = 1; i <= n; i++) {
			scanf("%d%d", &a[i].x, &a[i].y);
		}
		sort(a + 1, a + 1 + n, cmp1);
		for (int i = 1; i <= n; i++) {
			a[i].rank = i;
		}
		sort(a + 1, a + 1 + n, cmp2);
		build(1, 1, n);
		LL ans = 0;
		for (int i = 1; i <= n; i++) {
			LL tmp = query_cnt(1, 1, n, 1, a[i].rank);
			update_cnt(1, 1, n, a[i].rank, tmp, a[i].x);
			ans = (ans + query_sum(1, 1, n, a[i].rank, n) * qpow(3, a[i].y) % mod) % mod;
			if(a[i].rank != n)
				update_sum(1, 1, n, a[i].rank + 1, n);
		}
		printf("%lld
", ans);
	}
}

  

原文地址:https://www.cnblogs.com/pkgunboat/p/10550737.html