洛谷P1527 矩阵乘法——二维树状数组+整体二分

题目:https://www.luogu.org/problemnew/show/P1527

整体二分,先把所有询问都存下来;

然后二分一个值,小于它的加到二维树状数组的前缀和里,判断一遍所有询问,就分出了这些询问的答案是否大于这个值;

然后分组递归下去求解即可;

注意加二维树状数组的那个nw是全局变量,在不同的层中不停调整;

二分的范围最好是mn-1到mx+1,不然有些询问的ans会没有赋上值。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int const maxn=505*505,maxm=60005;
int n,m,f[505][505],nn,mx,mn,nw,ans[maxm],p[maxm],tmp[maxm];
bool mk[maxm];
struct N{int x,y,val;}a[maxn];
struct Q{int x1,x2,y1,y2,k;}q[maxm];
bool cmp(N x,N y){return x.val<y.val;}
void add(int x,int y,int val)
{
    for(int i=x;i<=n;i+=(i&-i))
        for(int j=y;j<=n;j+=(j&-j))
            f[i][j]+=val;
}
int query(int x,int y)
{
    int ret=0;
//    for(int i=1;i<=x;i+=(i&-i))
//        for(int j=1;j<=y;j+=(j&-j))
//            ret+=f[i][j];
    for(int i=x;i;i-=(i&-i))
        for(int j=y;j;j-=(j&-j))
            ret+=f[i][j];
    return ret;
}
int ask(int i)
{
    int x1=q[i].x1,x2=q[i].x2,y1=q[i].y1,y2=q[i].y2;
    return query(x1-1,y1-1)+query(x2,y2)-query(x1-1,y2)-query(x2,y1-1);
}
void solve(int l,int r,int L,int R)
{
    if(l>r||L==R)return;
    int mid=((L+R)>>1);
    while(a[nw+1].val<=mid&&nw<nn)add(a[nw+1].x,a[nw+1].y,1),nw++;//nw是全局变量 
    while(a[nw].val>mid)add(a[nw].x,a[nw].y,-1),nw--;
    int cnt=0;
    for(int i=l;i<=r;i++)
    {
        if(ask(p[i])>=q[p[i]].k)mk[i]=1,cnt++,ans[p[i]]=mid;//若p[i]中小于mid的数比k多,说明第k小的数比mid小 
        else mk[i]=0;
    }
    int l1=l-1,l2=l+cnt-1;
    for(int i=l;i<=r;i++)
    {
        if(mk[i])tmp[++l1]=p[i];//二分 
        else tmp[++l2]=p[i];//不是i!! 
    }
    for(int i=l;i<=r;i++)p[i]=tmp[i];
    solve(l,l1,L,mid);solve(l1+1,r,mid+1,R);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
        {
            scanf("%d",&a[++nn].val);
            a[nn].x=i;a[nn].y=j;
            mn=min(mn,a[nn].val);mx=max(mx,a[nn].val);
        }    
    sort(a+1,a+nn+1,cmp);
    for(int i=1;i<=m;i++)
        scanf("%d%d%d%d%d",&q[i].x1,&q[i].y1,&q[i].x2,&q[i].y2,&q[i].k),p[i]=i;
    solve(1,m,mn-1,mx+1);
    for(int i=1;i<=m;i++)printf("%d
",ans[i]);
    return 0;
} 
原文地址:https://www.cnblogs.com/Zinn/p/9174006.html