【题解】 CF1404C Fixed Point Removal 线段树+树状数组+带悔贪心

Legend

Link ( extrm{to Codeforces})

给定长度为 (n (1 le n le 3 imes 10^5)) 的数组 (a_i (1 le a_i le n))

每次你可以选择删除一个位置上的数字当且仅当它的下标等于数字本身,即 (a_i=i)

删除后数组后面一段会平移过来,改变下标。

(q (1 le q le 3 imes 10^5)) 组询问,给出 (x,y (0 le x,y)(x+y < n)) 询问:

前面 (x) 个强制不能删除,后面 (y) 个强制不能删除,最多可以删掉多少个数字。

Editorial

Inspiration

考虑只询问一次怎么做?每次从右边找起,能删就删除。但这样是 (O(qn^2)) 的。

题目中提到的查询是无视一个前缀和一个后缀,无视后缀应该很好做,可能随便减减就能解决。

无视前缀则留下了一个后缀,这不禁让人想到预处理每一个后缀的答案。

所以我们从右往左依次加入数字,这样每次都考虑的是一个后缀。

假设单独选出 ([l,n]) 区间,最初位置在 (x) 上的数字可以被删掉,那么显然,选出 ([l-1,n]) 的时候也能被删掉。

所以我们每加入一个新的数字就去检查有哪些数字可以被删去,这个怎么快速找呢?

optimization

我们维护一个初始的 (v_i=i-a_i)

  • (v_i < 0) 的都一定不能被删除,因为数字只能前移;
  • (v_i =0) 的是可以被删除的;
  • (v_i > 0) 是潜在的可能被以后删除的。

每次只要找到一个 (v_x=0) 的位置,删除它即可。

删除一个位于 (x) 的数字之后,我们就手动把 (v_i (i in [x,n])) 全部 (-1),表示前移一位。

发现我们可以用线段树维护,找到最右侧的 (v_i=0) 的位置,这样就保证不会把其他 (v_j=0) 的位置破坏掉。

那么这样直到最后我们对于每一个位置上的数字,都可以得到一个二元组 ((suf_i ,id_i))

(suf_i) 表示的是这个数字是在第几个位置上的数字被加进来之后才删掉的。

(id_i) 表示这个位置的最初下标。

考虑对于每一组 ((x,y)) 的询问,我们要求什么,实际上是要求形如 ((i,j)) 且同时满足 (i ge x+1)(j le n-y) 的二元组数量。

这个离线之后就是树状数组板子了。

Code

我在代码中,并不是维护的 (v_i=0) 最靠右的位置,而是选择了一个 (v_i<0) 的位置,这样也是可以通过的。

为什么呢?可以类比带悔贪心的思路呀,一定是可以通过改变删除顺序使得这个位置也能被删除的。

#include <bits/stdc++.h>

#define LL long long

const int MX = 3e5 + 233;

using namespace std;

int read(){
	char k = getchar(); int x = 0;
	while(k < '0' || k > '9') k = getchar();
	while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
	return x;
}

int a[MX];

struct node{
	int l ,r ,mn ,mnfr ,add;
	node *lch ,*rch;
	node operator +(node B)const{
		node C;
		C.mn = min(this->mn ,B.mn);
		C.mnfr = (C.mn == this->mn) ? this->mnfr : B.mnfr;
		return C;
	}
}*root;

void pushup(node *x){
	x->mn = min(x->lch->mn ,x->rch->mn);
	x->mnfr = x->lch->mn == x->mn ? x->lch->mnfr : x->rch->mnfr;
}

void doadd(node *x ,int v){x->mn += v ,x->add += v;}
void pushdown(node *x){
	if(x->add){
		doadd(x->lch ,x->add);
		doadd(x->rch ,x->add);
		x->add = 0;
	}
}

node *build(int l ,int r ,int *__){
	node *x = new node; x->l = l ,x->r = r; x->add = 0;
	if(l == r) x->lch = x->rch = nullptr ,x->mn = __[l] ,x->mnfr = l;
	else{int mid = (l + r) >> 1;
		x->lch = build(l ,mid ,__);
		x->rch = build(mid + 1 ,r ,__);
		pushup(x);
	}return x;
}

void add(node *x ,int l ,int r ,int v){
	if(l <= x->l && x->r <= r) return doadd(x ,v);
	pushdown(x);
	if(l <= x->lch->r) add(x->lch ,l ,r ,v);
	if(r > x->lch->r) add(x->rch ,l ,r ,v);
	return pushup(x);
}

node query(node *x ,int l ,int r){
	if(l <= x->l && x->r <= r) return *x;
	pushdown(x);
	if(l <= x->lch->r && r > x->lch->r) return query(x->lch ,l ,r) + query(x->rch ,l ,r);
	if(l <= x->lch->r) return query(x->lch ,l ,r);
	return query(x->rch ,l ,r);
}

void change(node *x ,int p ,int v){
	if(p <= x->l && x->r <= p) return x->mn = v ,void();
	pushdown(x);
	if(p <= x->lch->r) change(x->lch ,p ,v);
	if(p > x->lch->r) change(x->rch ,p ,v);
	return pushup(x);
}

int pcnt;
struct Point{
	int x ,y ,type ,coef ,id;
	bool operator <(const Point &B)const{
		return x == B.x ? y == B.y ? type < B.type : y < B.y : x < B.x;
	}
}P[MX * 3];

class BIT{
	private:
		int data[MX];
	public:
		void add(int x ,int v){while(x < MX) data[x] += v ,x += x & -x;}
		int sum(int x){int s = 0; while(x > 0) s += data[x] ,x -= x & -x; return s;}
}C;

int Ans[MX];
int main(){
	int n = read() ,q = read();
	for(int i = 1 ; i <= n ; ++i){
		a[i] = read();
		a[i] = (a[i] > i ? INT_MAX : i - a[i]);
	}
	root = build(1 ,n ,a);
	for(int i = n ; i ; --i){
		while(true){
			node kksk = query(root ,i ,n);
			if(kksk.mn <= 0){
				P[++pcnt] = (Point){i ,kksk.mnfr ,0 ,0 ,0};
				// fprintf(stderr ,"(%d ,%d)
" ,i ,kksk.mnfr);
				add(root ,kksk.mnfr ,n ,-1);
				change(root ,kksk.mnfr ,INT_MAX);
			}else break;
		}
	}
	for(int i = 1 ,x ,y ; i <= q ; ++i){
		x = read() ,y = read();
		P[++pcnt] = (Point){n ,n - y ,1 ,1 ,i};
		P[++pcnt] = (Point){x ,n - y ,1 ,-1 ,i};
	}
	sort(P + 1 ,P + 1 + pcnt);
	for(int i = 1 ; i <= pcnt ; ++i){
		if(P[i].type == 0){
			C.add(P[i].y ,1);
		}else{
			Ans[P[i].id] += P[i].coef * C.sum(P[i].y);
		}
	}
	for(int i = 1 ; i <= q ; ++i)
		printf("%d
" ,Ans[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/imakf/p/13705155.html