[BJOI2018]二进制

题目链接

题目描述:

pupil 发现对于一个十进制数,无论怎么将其的数字重新排列,均不影响其是不是 的倍数。他想研究对于二进制,是否也有类似的性质。
于是他生成了一个长为n 的二进制串,希望你对于这个二进制串的一个子区间,能求出其有多少位置不同的连续子串,满足在重新排列后(可包含前导0)是一个3 的倍数。
两个位置不同的子区间指开始位置不同或结束位置不同。
由于他想尝试尽量多的情况,他有时会修改串中的一个位置,并且会进行多次询问。

首先,看到区间的子串,想到线段树区间合并。
考虑一个子问题:判断一个子串是否满足条件。
首先,二进制化为十进制的方法是对应位×位权。
可以发现,位权(mod3)的值是(1,2,1,2)交替。
设这个子串有x个1放在2上,y个放在1上,则应满足((2x+y)mod3=0),即(x=y(mod3))
设该子串长度为len,有s个1。
那么,贪心的想,x和y相差0或3。
分情况讨论:

  1. (len\%2=0)(s\%2=0)。需要满足(s/2<=len/2),显然可以。
  2. (len\%2=0)(s\%2=1)。即(x+y=s,x-y=3),解得(x=(s+3)/2)。需要满足((s+3)/2<=len/2),即(s<=len-3)
  3. (len\%2=1)(s\%2=0)。需要满足(s/2<=(len-1)/2),显然可以。
  4. (len\%2=1)(s\%2=1)。同理,(x=(s+3)/2)。需要满足((s+3)/2<=(len+1)/2),即(s<=len-2)

可以发现,s为偶数时一定可以。s为奇数时比较复杂,难以统计(因为即使len确定,满足要求的s有很多,无法进行区间合并)。
考虑容斥,求不合法的。
若不合法,则(s\%2=1)
分两种情况讨论:

  1. (len\%2=0),即(s>=len-2)。因为(s\%2=1,s<=len),所以(s=len-1)
  2. (len\%2=1),即(s>=len-1)。因为(s\%2=1,s<=len),所以(s=len)

此外,若(s=1),也不合法。把这部分减去,再加上(s=1)(len<=2)的情况(被算了两次)。
时间复杂度:(O(mlogn)),常数很大。
注意区间合并的细节。

代码:

#include <stdio.h> 
#define ll long long 
struct SJd {
	ll s[2][2];
	SJd() {
		s[0][0] = s[0][1] = s[1][0] = s[1][1] = 0;
	}
};
SJd operator * (SJd a, SJd b) {
	SJd rt;
	for (int x = 0; x < 2; x++) {
		for (int y = 0; y < 2; y++) {
			int z = (x + y) % 2;
			rt.s[z][0] += a.s[x][0] * b.s[y][0];
			rt.s[z][1] += a.s[x][0] * b.s[y][1] + a.s[x][1] * b.s[y][0];
		}
	}
	return rt;
}
SJd operator + (SJd a, SJd b) {
	SJd rt;
	rt.s[0][0] = a.s[0][0] + b.s[0][0];
	rt.s[0][1] = a.s[0][1] + b.s[0][1];
	rt.s[1][0] = a.s[1][0] + b.s[1][0];
	rt.s[1][1] = a.s[1][1] + b.s[1][1];
	return rt;
}
SJd he[400010],zg[400010],zl[400010],zr[400010];
ll s0[400010],s1[400010],s2[400010];
bool h0[400010],h1[400010];
int sz[100010],l0[400010],l1[400010],r0[400010],r1[400010],su[400010],ma = 0,sl;
void fuz(int i, int j) {
	he[i]=he[j];zg[i]=zg[j];
	zl[i]=zl[j];zr[i]=zr[j];
	s0[i]=s0[j];s1[i]=s1[j];s2[i]=s2[j];
	h0[i]=h0[j];h1[i]=h1[j];
	l0[i]=l0[j];l1[i]=l1[j];
	r0[i]=r0[j];r1[i]=r1[j];
	su[i]=su[j];
}
void merge(int i, int cl, int cr, int m) {
	h0[i] = h0[cl] && h0[cr];
	h1[i] = (h1[cl] && h0[cr]) || (h0[cl] && h1[cr]);
	l0[i] = l0[cl];
	if (h0[cl]) l0[i] += l0[cr];
	r0[i] = r0[cr];
	if (h0[cr]) r0[i] += r0[cl];
	l1[i] = l1[cl];
	if (h0[cl]) l1[i] += l1[cr];
	if (h1[cl]) l1[i] += l0[cr];
	r1[i] = r1[cr];
	if (h0[cr]) r1[i] += r1[cl];
	if (h1[cr]) r1[i] += r0[cl];
	s0[i] = s0[cl] + s0[cr] + 1ll * r0[cl] * l0[cr];
	s1[i] = s1[cl] + s1[cr] + 1ll * r0[cl] * l1[cr] + 1ll * r1[cl] * l0[cr];
	s2[i] = s2[cl] + s2[cr] + (sz[m - 1] + sz[m] == 1);
	zg[i] = zg[cl] * zg[cr];
	zl[i] = zl[cl] + zg[cl] * zl[cr];
	zr[i] = zr[cr] + zg[cr] * zr[cl];
	he[i] = he[cl] + he[cr] + zr[cl] * zl[cr];
	su[i] = su[cl] + su[cr];
}
void pushup(int i, int l, int r) {
	int m = (l + r) >> 1;
	merge(i, i << 1, (i << 1) | 1, m);
}
void getddz(int i, int x) {
	l0[i] = r0[i] = s0[i] = h0[i] = (x == 0);
	l1[i] = r1[i] = s1[i] = h1[i] = (x == 1);
	he[i] = zg[i] = zl[i] = zr[i] = SJd();
	he[i].s[1][x ^ 1] += 1;
	zg[i].s[1][x ^ 1] += 1;
	zl[i].s[1][x ^ 1] += 1;
	zr[i].s[1][x ^ 1] += 1;
	su[i] = x;
	s2[i] = 0;
}
void jianshu(int i, int l, int r) {
	if (i > ma) ma = i;
	if (l + 1 == r) {
		getddz(i, sz[l]);
		return;
	}
	int m = (l + r) >> 1;
	jianshu(i << 1, l, m);
	jianshu((i << 1) | 1, m, r);
	pushup(i, l, r);
}
void xiugai(int i, int l, int r, int j) {
	if (l + 1 == r) {
		sz[l] ^= 1;
		getddz(i, sz[l]);
		return;
	}
	int m = (l + r) >> 1;
	if (j < m) xiugai(i << 1, l, m, j);
	else xiugai((i << 1) | 1, m, r, j);
	pushup(i, l, r);
}
void getans(int i, int l, int r, int L, int R) {
	if (R <= l || r <= L) return;
	if (L <= l && r <= R) {
		if (sl == ma) fuz(sl + 1, i);
		else merge(sl + 1, sl, i, l);
		sl += 1;
		return;
	}
	int m = (l + r) >> 1;
	getans(i << 1, l, m, L, R);
	getans((i << 1) | 1, m, r, L, R);
}
ll getans(int n, int l, int r) {
	int s = r - l + 1;
	ll zs = 1ll * s * (s + 1) / 2;
	sl = ma;
	getans(1, 1, n + 1, l, r + 1);
	zs -= (he[sl].s[0][1] + he[sl].s[1][0] + s1[sl]);
	zs += (su[sl] + s2[sl]);
	return zs;
}
int main() {
	int n,m;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%d", &sz[i]);
	scanf("%d", &m);
	jianshu(1, 1, n + 1);
	for (int i = 0; i < m; i++) {
		int lx;
		scanf("%d", &lx);
		if (lx == 1) {
			int a;
			scanf("%d", &a);
			xiugai(1, 1, n + 1, a);
		} else {
			int a,b;
			scanf("%d%d", &a, &b);
			printf("%lld
", getans(n, a, b));
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/lnzwz/p/11348249.html