[题解] 面包

题意

给定一个 (n imes m) 的网格,其中有 (k) 个关键点,求所有至少含有一个关键点的子矩形所含关键点数的方差。

(n, m le 10^9, k le 2000)

思路

(s_0, s_1, s_2) 为所有合法子矩形所含关键点数的 (0, 1, 2) 次和,容易推出方差与 (s_{0, 1, 2}) 的关系。

(s_n (n > 0)) 容易计算,考虑 (x^n) 的组合意义就是有序可重复地取出 (n) 个点,因此枚举 (n) 个点,计算包含它们的矩形数量即可,复杂度 (mathcal O(n^2))

比较难算的是 (s_0)考虑对矩形中的唯一的一个点进行计数,这里计数以 (x) 为第一关键字,(y) 为第二关键字排序的最后一个点。

先对点排序,下文用 ([x_l, x_r], [y_l, y_r]) 表示一个子矩形,按顺序枚举点 ((x, y))

  • 显然 (x_l in [1, x]) 都是合法的。

  • 依次扫描位于 ((x, y)) 之后的点,这些点不能被子矩形覆盖,因此每个点会对 (y_l, y_r) 中的一个加以限制,每加入一个点计算器 ([x_i, x_{i+1})) 贡献即可。

复杂度 (mathcal O(n^2))

代码

#include <cstdio>
#include <algorithm>
using namespace std;
#define File(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)

const int mod = 998244353;
inline int add(int x, int y) {return x+y>=mod ? x+y-mod : x+y;}
inline int sub(int x, int y) {return x-y<0 ? x-y+mod : x-y;}
inline int mul(int x, int y) {return 1LL * x * y % mod;}
inline void inc(int &x, int y=1) {x += y; if(x >= mod) x -= mod;}
inline void dec(int &x, int y=1) {x -= y; if(x < 0) x += mod;}
inline int power(int x, int y){
  int res = 1;
  for(; y; y>>=1, x = mul(x, x)) if(y & 1) res = mul(res, x);
  return res;
}
inline int inv(int x){return power(x, mod - 2);}
template<class T> void upmax(T &x, T y){x = x>y ? x : y;}
template<class T> void upmin(T &x, T y){x = x<y ? x : y;}

const int N = 2005;
struct Pt{
  int x, y;
}a[N];

int main(){
  int n, m, k;
  scanf("%d%d%d", &n, &m, &k);
  for(int i=1; i<=k; i++)
    scanf("%d%d", &a[i].x, &a[i].y);
  sort(a + 1, a + 1 + k, [](Pt x, Pt y){
    if(x.x == y.x) return x.y < y.y;
    return x.x < y.x;
  });
  int s0 = 0, s1 = 0, s2 = 0;
  for(int i=1; i<=k; i++)
    inc(s1, mul(mul(a[i].x, n - a[i].x + 1), mul(a[i].y, m - a[i].y + 1)));
  for(int i=1; i<=k; i++)
    for(int j=1; j<=k; j++){
      int xl = min(a[i].x, a[j].x), xr = max(a[i].x, a[j].x);
      int yl = min(a[i].y, a[j].y), yr = max(a[i].y, a[j].y);
      inc(s2, mul(mul(xl, n - xr + 1), mul(yl, m - yr + 1)));
    }
  a[k + 1].x = a[k].x;
  for(int i=1; i<=k; i++){
    int yl = 0, yr = m + 1;
    int now = mul(a[i + 1].x - a[i].x, mul(a[i].y, m - a[i].y + 1));
    for(int j=i+1; j<=k; j++){
      if(a[j].y <= a[i].y) upmax(yl, a[j].y);
      if(a[j].y >= a[i].y) upmin(yr, a[j].y);
      if(a[j].x != a[j + 1].x)
        inc(now, mul(mul(yr - a[i].y, a[i].y - yl), a[j + 1].x - a[j].x));
    }
    inc(now, mul(mul(yr - a[i].y, a[i].y - yl), n - a[k].x + 1));
    inc(s0, mul(now, a[i].x));
  }
  int avg = mul(s1, inv(s0));
  printf("%d
", add(mul(sub(s2, mul(2, mul(s1, avg))), inv(s0)), mul(avg, avg)));
  return 0;
}
原文地址:https://www.cnblogs.com/RiverHamster/p/sol-oj3058.html