【题解】 「NOI2017」整数 线段树+二分+压位 LOJ2302

Legend

Link to LOJ

请维护一个高精二进制数 \(s\),支持操作 \(n\ (0 \le n \le 10^6)\) 次:

  • 加或减 \(a\times 2^{b}\)\((|a|\le 10^9,0\le b\le 30n)\)
  • 查询 \(s \operatorname{and} 2^b\) 的结果转化为 \(\textrm{bool}\) 后是否为真。

时空 \(\textrm{2s/512MB}\)

Editorial

作为 \(\textrm{NOI2017}\) 的第一题,必定是一道良心送温暖题,让我们一起为出题人松松松鼓掌。

brute

容易看到底下有一些部分分,映入眼帘的便是 \(|a|=1\),于是我们就想到把每一个加减操作看成 \(O(\log a)\) 次加减单个二进制位。怎么样?是不是看起来简单一点了?

考虑直接模拟。假设现在做加法的是位置 \(l\),如果这一位是 \(0\) 就直接改成 \(1\),否则即找到之后第一个为 \(0\) 的位置 \(p\ (l < p)\),把 \([l,p-1]\) 都改成 \(0\),并把位置 \(p\) 改成 \(1\)

减法同理。如果这一位是 \(1\) 就直接改成 \(0\),否则即找到之后第一个为 \(1\) 的位置 \(p\ (l < p)\),把 \([l,p-1]\) 都改成 \(1\),并把位置 \(p\) 改成 \(0\)

以上两个操作可以直接在线段树上二分找到,打区间覆盖标记。查询则可以直接使用线段树单点查询。

于是就得到了一个复杂度为 \(O(n \log n\log a)\) 的做法。

optimization

上述做法的瓶颈在于:

  • 数组长度是 \(30n\),凭空多出来一个常数。
  • 要进行拆位,\(1\) 个操作变成了 \(\log a\) 个。

不妨往反方向考虑,把数组压位,连续 \(32\) 个数字用一个 \(\textrm{unsigned int}\) 存储。

这样子对于一个修改操作我们最多只要拆成两个。而查询连续 \(1\) 段和连续 \(0\) 段依然可以用线段树实现,代码相差无几。

但这样就可以把复杂度优化到 \(O\left(\dfrac{n \log n \log a}{\omega}\right)\),其中 \(\omega\) 为压位大小。

Code

写的时候有点犯迷糊,最开始用 \(\textrm{unsigned int}\) 存了读入的 \(a\),后来又没写线段树的 \(\textrm{pushup pushdown}\),最后发现线段树二分写错了……白白浪费了一个下午+晚上。

就这样修修补补写出了下面这些东西,有点繁琐了,但还可以看。

LOJ 上这破烂可以在 \(\textrm{800ms}\) 内跑过。

// Author : Imakf
#include <bits/stdc++.h>

using namespace std;

#define LL long long
#define debug(...) fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
	freopen(#x".in" ,"r" ,stdin);\
	freopen(#x".out" ,"w" ,stdout)

LL read(){
	char k = getchar(); LL x = 0 ,flg = 1;
	while(k < '0' || k > '9')
		flg *= k == '-' ? -1 : 1 ,k = getchar();
	while(k >= '0' && k <= '9')
		x = x * 10 + k - '0' ,k = getchar();
	return x * flg;
}


const int MX = 1e6 + 233;

struct node{
	int l ,r ,c;
	unsigned int num;
	bool zero ,all ,cov;
	node *lch ,*rch;
}*root;

void pushup(node *x){
	x->zero = x->lch->zero & x->rch->zero;
	x->all  = x->lch->all  & x->rch->all;
}

node *build(int l ,int r){
	node *x = new node;
	x->l       = l;
	x->r       = r;
	x->zero    = true;
	x->all     = false;
	x->cov     = false;
	x->c       = 0;
	x->num     = 0;
	if(l == r){
		x->lch = nullptr;
		x->rch = nullptr;
	}
	else{
		int mid = (l + r) >> 1;
		x->lch = build(l ,mid);
		x->rch = build(mid + 1 ,r);
		pushup(x);
	}return x;
}
void docov(node *x ,bool v){
	x->cov  = true;
	x->c    = v;
	x->zero = !v;
	x->all  = v;
	x->num  = v ? UINT_MAX : 0;
}
void pushdown(node *x){
	if(x->cov){
		x->cov = false;
		docov(x->lch ,x->c);
		docov(x->rch ,x->c);
	}
}

void cov(node *x ,int l ,int r ,bool val){
	if(l <= x->l && x->r <= r) return docov(x ,val);
	pushdown(x);
	if(l <= x->lch->r) cov(x->lch ,l ,r ,val);
	if(r > x->lch->r) cov(x->rch ,l ,r ,val);
	return pushup(x);
}

void add(node *x ,LL v){
	x->num += v;
	x->all  = x->num == UINT_MAX;
	x->zero = x->num == 0;
}

int __add(node *x ,int l ,int r){ // 找到最小的不是全 1 的 pos
	if(x->r < l || x->l > r) return 0;
	if(x->all) return 0;
	if(x->l == x->r){
		return add(x ,1) ,x->l;
	}
	pushdown(x);
	int ret = 0;
	if(x->lch->all || !(ret = __add(x->lch ,l ,r))){
		ret = __add(x->rch ,l ,r);
	}
	pushup(x);
	return ret;
}

void add(node *x ,int p ,LL val){
	if(x->l == x->r){
		if(x->num + val > UINT_MAX){
			x->num = (x->num + val) & UINT_MAX;
			add(x ,0);
			int pos = __add(root ,p + 1 ,MX);
			if(pos - 1 >= p + 1) cov(root ,p + 1 ,pos - 1 ,0);
		}
		else add(x ,val);
		return ;
	}
	pushdown(x);
	if(p <= x->lch->r) add(x->lch ,p ,val);
	else add(x->rch ,p ,val);
	return pushup(x);
}

void add(LL a ,LL b){
	// add a*(2^b)
	int bit32 = b / 32 ,bit = b % 32;
	LL f = a << bit;
	if(f > UINT_MAX){
		add((f & UINT_MAX) >> bit ,b);
		add(f >> 32 ,(bit32 + 1) * 32);
		return ;
	}
	// debug("%lld %lld\n" ,a ,b);
	add(root ,bit32 ,f);
}

int __del(node *x ,int l ,int r){ // 找到最小的不是全 0 的 pos
	// debug("Find [%d ,%d] ,allzero = %d\n" ,x->l ,x->r ,x->zero);
	if(x->r < l || x->l > r) return 0;
	if(x->zero) return 0;
	if(x->l == x->r){
		return add(x ,-1) ,x->l;
	}
	pushdown(x);
	int ret = 0;
	if(x->lch->zero || !(ret = __del(x->lch ,l ,r))){
		ret = __del(x->rch ,l ,r);
	}
	pushup(x);
	return ret;
}

void del(node *x ,int p ,LL val){
	if(x->l == x->r){
		if(x->num - val < 0){
			x->num = x->num - val + UINT_MAX + 1;
			add(x ,0);
			int pos = __del(root ,p + 1 ,MX);
			if(pos - 1 >= p + 1) cov(root ,p + 1 ,pos - 1 ,1);
		}
		else add(x ,-val);
		return ;
	}
	pushdown(x);
	if(p <= x->lch->r) del(x->lch ,p ,val);
	else del(x->rch ,p ,val);
	return pushup(x);
}

void sub(LL a ,LL b){
	int bit32 = b / 32 ,bit = b % 32;
	LL f = a << bit;
	if(f > UINT_MAX){
		sub((f & UINT_MAX) >> bit ,b);
		sub(f >> 32 ,(bit32 + 1) * 32);
		return ;
	}
	del(root ,bit32 ,f);
}

LL query(node *x ,int p){
	if(x->l == x->r) return x->num;
	pushdown(x);
	if(p <= x->lch->r) return query(x->lch ,p);
	return query(x->rch ,p);
}

int query(int pos){
	int bit32 = pos / 32 ,bit = pos % 32;
	return (query(root ,bit32) >> bit) & 1;
}

void output(node *x){
	if(x->l == x->r){
		for(int i = 0 ; i < 32 ; ++i){
			debug("%u" ,(x->num >> i) & 1);
		}
		return;
	}
	pushdown(x);
	output(x->lch) ,output(x->rch);
}

int main(){
	__FILE([NOI2017]整数);
	
	int n = read(); read() ,read() ,read();
	root = build(0 ,MX);
	for(LL i = 1 ,op ,a ,b ; i <= n ; ++i){
		// debug("%d\n" ,i);
		op = read();
		if(op == 1){
			a = read() ,b = read();
			// assert(a >= 0);
			if(a > 0) add(a ,b);
			else sub(-a ,b);
		}
		else{
			a = read();
			printf("%d\n" ,query(a));;
		}
		// output(root);
		// debug("\n");
	}
}
原文地址:https://www.cnblogs.com/imakf/p/13623627.html