[校内训练2021_03_16]B 矩阵竟然能分治

题目大意:有一个n*m的01矩阵,统计出框架的个数。框架的定义:一个四元组(L,U,D,R)即一个矩形,满足L<R,U<D,并且四条边上全都是0。

思考:我们考虑分治(虽然我不知道为什么能想到这个):我们先把原本的矩阵竖着切一刀,那么我们要分别计算出左右两边的数组f[l][r],表示能够跨过分界线的半个框架的个数。我们不难用数据结构得到一个n^2log^2n的做法(分治有一个log,区间求和一个log)。

但我们注意到,这道题存在特殊的偏序关系,即一个横着的框架如果能和另外一个横着的框架拼起来,那么我们只统计比较短的那一个。因此我们可以枚举上边界,往下或者上拓展,如果遇到一个更短的边界,那么在f[l][r]上加上相应的答案。

最后切矩阵的时候要交替地切。

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 typedef long long int ll;
  4 const int maxn=3505;
  5 ll ans;
  6 bool a[maxn][maxn],b[maxn][maxn];
  7 int U[maxn][maxn],D[maxn][maxn],L[maxn][maxn],R[maxn][maxn];
  8 int vis[maxn],g[maxn];
  9 inline void getL(int n,int m,int f[maxn][maxn])
 10 {
 11     for(int i=1;i<=n;++i)
 12         for(int j=1;j<=n;++j)
 13             f[i][j]=0;
 14     for(int i=0;i<=n+1;++i)
 15         b[i][0]=b[i][m+1]=1;
 16     for(int i=0;i<=m+1;++i)
 17         b[0][i]=b[n+1][i]=1;
 18     for(int i=0;i<=n+1;++i)
 19         for(int j=0;j<=m+1;++j)
 20             if(b[i][j])
 21                 L[i][j]=j,U[i][j]=i;
 22             else
 23                 L[i][j]=L[i][j-1],U[i][j]=U[i-1][j];
 24     for(int i=n+1;i>=0;--i)
 25         for(int j=m+1;j>=0;--j)
 26             if(b[i][j])
 27                 R[i][j]=j,D[i][j]=i;
 28             else
 29                 R[i][j]=R[i][j+1],D[i][j]=D[i+1][j];
 30     for(int i=1;i<=n;++i)
 31     {
 32         vis[i]=0;
 33         for(int j=1;j<=n;++j)
 34             g[j]=0;
 35         for(int j=1;j<R[i][1];++j)
 36             ++g[i],--g[D[i][j]];
 37         for(int j=i+1;j<=n;++j)
 38         {
 39             g[j]+=g[j-1];
 40             if(R[j][1]>=R[i][1])
 41                 f[i][j]+=g[j];
 42         }
 43         g[i]=0;
 44         for(int j=1;j<R[i][1];++j)
 45             ++g[i],--g[U[i][j]];
 46         for(int j=i-1;j>=1;--j)
 47         {
 48             g[j]+=g[j+1];
 49             if(R[j][1]>R[i][1])// !!!!!!!
 50                 f[j][i]+=g[j];
 51         }
 52     }
 53 }
 54 inline void getU(int n,int m,int f[maxn][maxn])
 55 {
 56     if(n<=m)
 57         for(int i=1;i<=n;++i)
 58             for(int j=i;j<=m;++j)
 59                 swap(b[i][j],b[j][i]);
 60     else
 61         for(int i=1;i<=n;++i)
 62             for(int j=1;j<=min(i,m);++j)
 63                 swap(b[i][j],b[j][i]);
 64     getL(m,n,f);
 65 }
 66 inline void getR(int n,int m,int f[maxn][maxn])
 67 {
 68     for(int i=1;i<=n;++i)
 69         for(int j=1;j<=m/2;++j)
 70             swap(b[i][j],b[i][m-j+1]);
 71     getL(n,m,f);
 72 }
 73 inline void getD(int n,int m,int f[maxn][maxn])
 74 {
 75     for(int i=1;i<=n/2;++i)
 76         for(int j=1;j<=m;++j)
 77             swap(b[i][j],b[n-i+1][j]);
 78     getU(n,m,f);
 79 }
 80 int f1[maxn][maxn],f2[maxn][maxn];
 81 inline void solve(int l,int r,int u,int d)
 82 {
 83     if(l>=r||u>=d)
 84         return;
 85     if(r-l<d-u)
 86     {
 87         int mid=(u+d)>>1;
 88         for(int i=u;i<=mid;++i)
 89             for(int j=l;j<=r;++j)
 90                 b[i-u+1][j-l+1]=a[i][j];
 91         getD(mid-u+1,r-l+1,f1);
 92         for(int i=mid+1;i<=d;++i)
 93             for(int j=l;j<=r;++j)
 94                 b[i-mid][j-l+1]=a[i][j];
 95         getU(d-mid,r-l+1,f2);
 96         for(int i=l;i<=r;++i)
 97             for(int j=i+1;j<=r;++j)
 98                 ans+=(ll)f1[i-l+1][j-l+1]*f2[i-l+1][j-l+1];
 99         solve(l,r,u,mid);
100         solve(l,r,mid+1,d);
101     }
102     else
103     {
104         int mid=(l+r)>>1;
105         for(int i=u;i<=d;++i)
106             for(int j=l;j<=mid;++j)
107                 b[i-u+1][j-l+1]=a[i][j];
108         getR(d-u+1,mid-l+1,f1);
109         for(int i=u;i<=d;++i)
110             for(int j=mid+1;j<=r;++j)
111             {
112                 b[i-u+1][j-mid]=a[i][j];
113             }
114         getL(d-u+1,r-mid,f2);
115         for(int i=u;i<=d;++i)
116             for(int j=i+1;j<=d;++j)
117                 ans+=(ll)f1[i-u+1][j-u+1]*f2[i-u+1][j-u+1];
118         solve(l,mid,u,d);
119         solve(mid+1,r,u,d);
120     }
121 }
122 int n,m;
123 int main()
124 {
125     freopen("two.in","r",stdin);
126     freopen("two.out","w",stdout);
127     ios::sync_with_stdio(false);
128     cin>>n>>m;
129     for(int i=1;i<=n;++i)
130     {
131         string str;
132         cin>>str;
133         for(int j=1;j<=m;++j)
134             if(str[j-1]=='.')
135                 a[i][j]=0;
136             else
137                 a[i][j]=1;
138     }
139     solve(1,m,1,n);
140     cout<<ans<<endl;
141     return 0;
142 }
View Code
原文地址:https://www.cnblogs.com/GreenDuck/p/14545515.html