LOJ3177「IOI2019」矩形区域【转化条件,计数】

给定 (n imes m) 的自然数矩阵 (a),求有多少个 ((r_1,r_2,c_1,c_2)) 满足 (1le r_1le r_2le n-2)(1le c_1le c_2le m-2)(forall iin[r_1,r_2],jin[c_1,c_2])(a_{i,j}<min(a_{i,c_1-1},a_{i,c_2+1},a_{r_1-1,j},a_{r_2+1,j}))

(n,mle 2500)(a_{i,j}le 7cdot 10^6)


首先可以发现,充要条件就是 (forall iin[r_1,r_2])(a_{i,c_1-1}) 右边第一个不小于它的值是 (a_{i,c_2+1})(a_{i,c_2+1}) 左边第一个不小于它的值是 (a_{i,c_1-1}),列同理。

对于每一行,满足条件的 ((c_1,c_2)) 可以用单调栈算出来,并且只有至多 (2m) 个。列同理。

考虑如何计算,先对每行算一遍,把所有 ((c_1,c_2)) 对应的行标号记录下来,然后从小到大枚举右边界 (c_2),把这一列算一遍,维护所有 ((r_1,r_2)) 对应的当前最右连续区间 ([lb_{r_1,r_2},rb_{r_1,r_2}]),然后枚举左边界 (c_1),遍历 ((c_1,c_2)) 对应的行标号连续段。

现在已经去掉了行限制,至于列限制就直接对其中一列做一遍,可能满足条件的 ((r_1,r_2)) 只有至多 (2n) 个,判断就看 ([lb_{r_1,r_2},rb_{r_1,r_2}]) 是否包含 ([c_1,c_2])

时间复杂度 (O(nm))

#include<bits/stdc++.h>
#define PB emplace_back
#define fi first
#define se second
using namespace std;
typedef pair<int, int> pii;
const int N = 2502, K = 5e7;
char buf[K], *in = buf;
int read(){
	int x = 0;
	for(;!isdigit(*in);++ in);
	for(;isdigit(*in);++ in) x = x * 10 + *in - '0';
	return x;
}
int n, m, a[N][N], lb[N][N], rb[N][N], stk[N], tp, ans;
vector<int> ok[N][N];
vector<pii> res;
void work(int *b, int l){
	res.resize(tp = 0);
	for(int i = 1;i <= l;++ i){
		while(tp && b[i] > b[stk[tp]]){
			if(i > stk[tp]+1) res.PB(stk[tp]+1, i-1);
			-- tp;
		}
		if(tp){
			if(i > stk[tp]+1) res.PB(stk[tp]+1, i-1);
			if(b[i] == b[stk[tp]]) -- tp;
		}
		stk[++tp] = i;
	}
}
void calc(int l, int r, int u, int d){
	for(int i = u-1;i <= d+1;++ i)
		a[0][i-u+2] = a[i][l];
	work(*a, d-u+3);
	for(pii p : res){
		int L = p.fi+u-2, R = p.se+u-2;
		ans += (lb[L][R] <= l && r <= rb[L][R]);
	}
}
int main(){
	fread(buf, 1, K, stdin);
	n = read(); m = read();
	for(int i = 1;i <= n;++ i)
		for(int j = 1;j <= m;++ j)
			a[i][j] = read();
	for(int i = 2;i < n;++ i){
		work(a[i], m);
		for(pii p : res) ok[p.fi][p.se].PB(i);
	}
	for(int r = 2;r < m;++ r){
		for(int i = 1;i <= n;++ i)
			a[0][i] = a[i][r];
		work(*a, n);
		for(pii p : res){
			if(rb[p.fi][p.se] < r-1) lb[p.fi][p.se] = r;
			rb[p.fi][p.se] = r;
		}
		for(int l = 2;l <= r;++ l){
			int len = ok[l][r].size();
			if(!len) continue;
			int lst = ok[l][r][0];
			for(int i = 1;i < len;++ i)
				if(ok[l][r][i] > ok[l][r][i-1]+1){
					calc(l, r, lst, ok[l][r][i-1]);
					lst = ok[l][r][i];
				}
			calc(l, r, lst, ok[l][r][len-1]);
		}
	}
	printf("%d
", ans);
}
原文地址:https://www.cnblogs.com/AThousandMoons/p/14882282.html