codeforces 1428F

题目链接:https://codeforces.com/problemset/problem/1428/F

(pre[i]) 表示从 i 位置开始向前最长的 1 的连续序列的长度
(mx[l,r]) 表示 ([l,r]) 内最长 1 的连续序列的长度
从前向后扫描,每次扫到一个新的 i ,考虑当前位置对之前所有序列的贡献,即

[sum_{L=1}^{n} mx[L,i] ]

考虑更新(mx),不难发现,(mx_{L=1}^{i}[L,i])是单调递减的,只需要找到满足 (mx[L,i] < pre[i]) 的最小的 L 的位置,
将[L,i]区间内的 (mx) 都 +1 ,线段树上二分找位置 + 区间更新即可 (inline卡常)

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<stack>
#include<queue>
using namespace std;
typedef long long ll;

const int maxn = 500010;
const int INF = 1000000007;

int n,m;
ll ans;
int a[maxn],pre[maxn];
char s[maxn];

struct SEG{
	int add,sum,mx,mi;
}t[maxn<<2];

inline void pushup(int i){
	t[i].sum = t[i<<1].sum + t[i<<1|1].sum;
	t[i].mi = min(t[i<<1].mi, t[i<<1|1].mi);
}

inline void pushdown(int i,int l,int r){
	if(t[i].add){
		t[i<<1].add += t[i].add;
		t[i<<1|1].add += t[i].add;
		t[i<<1].mi += t[i].add;
		t[i<<1|1].mi += t[i].add;
		int mid=(l+r) >> 1;
		t[i<<1].sum += (mid-l+1) * t[i].add;
		t[i<<1|1].sum += (r-mid) * t[i].add;
		t[i].add = 0;
	}
}

inline void mdf(int i,int k,int l,int r,int x,int y){
	if(x<=l && r<=y){
		t[i].sum += (r-l+1) * k;
		t[i].add += k;
		t[i].mi += k;
		return;
	}
	pushdown(i,l,r);
	int mid=(l+r)>>1;
	if(x<=mid) mdf(i<<1,k,l,mid,x,y);
	if(y>mid) mdf(i<<1|1,k,mid+1,r,x,y);
	pushup(i);
}

inline int query_min(int i,int l,int r,int x,int y){
	if(x<=l && r<=y){
		return t[i].mi;
	}
	pushdown(i,l,r);
	int res = INF;
	int mid = (l+r) >> 1;
	if(x<=mid) res = min(res,query_min(i<<1,l,mid,x,y));
	if(y>mid) res = min(res,query_min(i<<1|1,mid+1,r,x,y));
	return res;
}

inline int fin(int i,int k,int l,int r){
	if(l==r){
		return l;
	}

	pushdown(i,l,r);
	int mid = (l+r)>>1;
	if(query_min(1,1,n,l,mid) < k) return fin(i,k,l,mid);
	else return fin(i,k,mid+1,r);
}

ll read(){ ll s=0,f=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') f=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ s=s*10+ch-'0'; ch=getchar(); } return s*f; }

int main(){
	n = read();
	ans = 0;
	scanf("%s",s+1);
	
	for(register int i=1;i<=n;++i){
		if(s[i] == '0') a[i] = 0;
		else a[i] = 1;
	}

	ll sum = 0;
	for(register int i=1;i<=n;++i){
		if(a[i] == 1){
			pre[i] = pre[i-1] + 1;
			int pos = fin(1,pre[i],1,i);
			mdf(1,1,1,n,pos,i);
			sum += i - pos + 1;
//			printf("i:%d %d
",i,pos);
		}else{
			pre[i] = 0;
		}
		ans += sum;
	}
	printf("%lld
",ans);
	
	return 0;
}
原文地址:https://www.cnblogs.com/tuchen/p/13874091.html