联赛模拟17_简单的区间(分治)

T1简单的区间

看到模数比较小,1e6的范围,可以开个数组,就有思路了

不同的区间max的位置不确定,我们考虑分治。对于每个区间,我们只计算跨过中点的区间贡献。

并且分2种情况,最大值在左边,最大值在右边。这样扫一边的时候,另一边的边界指针是单调的。

首先维护一个桶buc[]

我们以假设max在左边为例

假设此左指针 (i)(mid) 之间的 (sum)(s1)

(mid+1) 到右指针 (j) 之间的 (sum)(s2)

对于枚举的左指针,右指针移动的同时用 (buc[x]),记录(s2 mod k = x)(s2) 的个数

如果((s1-max) mod k == y) ,那么我们只需要找到有多少个(s2),与 (s1-max) 加起来能整除 (k) 即可(也就是找 (buc[(k-y)%k])

(ma) x在右边是一样的, (buc[]) 记录 (s1) ,对于 (s2) 去找合适的 (s1) 即可

注意:

1.(max) 在左边时,右指针 (j) 移动的判断 (max[i]>=max[j]), 那么避免重复,(max) 在右边,左指针 (i) 移动判断应为 (max[i]<max[j]) ,不再考虑取等

2.对于每个区间,(buc[]) 清空的时候不要 (for(i=1->k)) ,否则复杂度太高(2e6log),用个栈维护一下哪些该删就能保证复杂度为 (O(nlogn))

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
using namespace std;
char buf[1<<20],*p1,*p2;
#define rint register int
#define gc() (p1==p2?(p2=buf+fread(p1=buf,1,1<<20,stdin),p1==p2?EOF:*p1++):*p1++)
#define read() ({
	rint x=0;register bool f=0;register char ch=gc();
	while(!isdigit(ch)) f|=ch=='-',ch=gc();
	while(isdigit(ch)) x=(x<<3)+(x<<1)+(ch&15),ch=gc();
	f?-x:x;
})
const int maxn=3e5+5;
int n,k;
int a[maxn];
int Max[maxn];
int sta[maxn],top;
int buc[1000000+5];
int sum[maxn];
long long ans;
void solve(rint l,rint r){
	if(l==r) return;
	rint mid=(l+r)/2;
	solve(l,mid),solve(mid+1,r);
	while(top) buc[sta[top--]]=0;
	Max[mid]=a[mid],Max[mid+1]=a[mid+1];
	for(rint i=mid-1;i>=l;--i) Max[i]=max(a[i],Max[i+1]);
	for(rint i=mid+2;i<=r;++i) Max[i]=max(a[i],Max[i-1]);
	rint now,i=mid,j=mid+1;
	while(i>=l){
		while(j<=r&&Max[j]<=Max[i]){ // 这里 <=
			const rint res=(sum[j]-sum[mid]+k)%k;
			buc[res]++;
			sta[++top]=res;
			++j;
		}
		now=((sum[mid]-sum[i-1]-Max[i])%k+k)%k;
		ans+=buc[(k-now)%k]; // 一定要mod k,因为now会==0
		--i;
	}
	while(top) buc[sta[top--]]=0;
	i=mid,j=mid+1;
	while(j<=r){
		while(i>=l&&Max[i]<Max[j]){ //这里 < ,防止算重
			const rint res=(sum[mid]-sum[i-1]+k)%k;
			buc[res]++;
			sta[++top]=res;
			--i;
		}
		now=((sum[j]-sum[mid]-Max[j])%k+k)%k;
		ans+=buc[(k-now)%k];
		++j;
	}
}
int main(){
	freopen("interval.in","r",stdin);
	freopen("interval.out","w",stdout);
	n=read(),k=read();
	for(rint i=1;i<=n;++i) sum[i]=(sum[i-1]+(a[i]=read()))%k;
	solve(1,n);
	printf("%lld
",ans);
	return 0;
}

原文地址:https://www.cnblogs.com/Lour688/p/13821228.html