BZOJ2738: 矩阵乘法(整体二分)

Description

  给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

Input

  第一行两个数N,Q,表示矩阵大小和询问组数;
  接下来N行N列一共N*N个数,表示这个矩阵;
  再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

Output

  对于每组询问输出第K小的数。

Sample Input

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3

Sample Output

1
3

HINT 

  矩阵中数字是109以内的非负整数;

  20%的数据:N<=100,Q<=1000;

  40%的数据:N<=300,Q<=10000;

  60%的数据:N<=400,Q<=30000;

  100%的数据:N<=500,Q<=60000。

解题思路:

整体二分思想非常浓。

将点权排序,二分mid时加入前mid个答案。

在二维树状数组上+1

最后二维树状数组查询前缀合就知道区间有多少个数。

最后统计答案就好了。

代码:

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<algorithm>
  4 struct data{
  5     int i,j;
  6     int val;
  7 }d[500000];
  8 struct que{
  9     int px,py,qx,qy;
 10     int no;
 11     int val;
 12 }q[100010],ss[100010],sp[100010];
 13 int line[501][501];
 14 int n,Q;
 15 int cnt;
 16 int top;
 17 int ans[1000000];
 18 int lowbit(int x)
 19 {
 20     return x&(-x);
 21 }
 22 void update(int x,int y,int v)
 23 {
 24     for(int i=x;i<=n;i+=lowbit(i))
 25     {
 26         for(int j=y;j<=n;j+=lowbit(j))
 27         {
 28             line[i][j]+=v;
 29         }
 30     }
 31     return ;
 32 }
 33 int query(int x,int y)
 34 {
 35     int ans=0;
 36     for(int i=x;i;i-=lowbit(i))
 37     {
 38         for(int j=y;j;j-=lowbit(j))
 39         {
 40             ans+=line[i][j];
 41         }
 42     }
 43     return ans;
 44 }
 45 bool cmp(data x,data y)
 46 {    
 47     return x.val<y.val;
 48 }
 49 void Insert(int no,int dir)
 50 {
 51     update(d[no].i,d[no].j,dir);
 52     return ;
 53 }
 54 int sum(int no)
 55 {
 56     int ret=0;
 57     ret=query(q[no].qx,q[no].qy)+query(q[no].px-1,q[no].py-1);
 58     ret-=query(q[no].qx,q[no].py-1)+query(q[no].px-1,q[no].qy);
 59     return ret;
 60 }
 61 void macrs(int l,int r,int ll,int rr)
 62 {
 63     if(ll>rr)
 64         return ;
 65     if(l==r)
 66     {
 67         for(int i=ll;i<=rr;i++)
 68             ans[q[i].no]=d[l].val;
 69         return ;
 70     }
 71     int mid=(l+r)>>1;
 72     while(top<mid)
 73         Insert(++top,1);
 74     while(top>mid)
 75         Insert(top--,-1);
 76     int sta1=0,sta2=0;
 77     for(int i=ll;i<=rr;i++)
 78     {
 79         if(sum(i)>=q[i].val)
 80             sp[++sta1]=q[i];
 81         else
 82             ss[++sta2]=q[i];
 83     }
 84     int sta=ll-1,lmid;
 85     for(int i=1;i<=sta1;i++)
 86         q[++sta]=sp[i];
 87     lmid=sta;
 88     for(int i=1;i<=sta2;i++)
 89         q[++sta]=ss[i];
 90     macrs(l,mid,ll,lmid);
 91     macrs(mid+1,r,lmid+1,rr);
 92     return ;
 93 }
 94 int main()
 95 {
 96     scanf("%d%d",&n,&Q);
 97     for(int i=1;i<=n;i++)
 98         for(int j=1;j<=n;j++)
 99         {
100             int tmp;
101             scanf("%d",&tmp);
102             d[++cnt]=(data){i,j,tmp};
103         }
104     std::sort(d+1,d+cnt+1,cmp);
105     for(int i=1;i<=Q;i++)
106     {
107         scanf("%d%d%d%d%d",&q[i].px,&q[i].py,&q[i].qx,&q[i].qy,&q[i].val);
108         if(q[i].px>q[i].qx)
109             std::swap(q[i].px,q[i].qx);
110         if(q[i].py>q[i].qy)
111             std::swap(q[i].py,q[i].qy);
112         q[i].no=i;
113     }
114     macrs(1,cnt,1,Q);
115     for(int i=1;i<=Q;i++)
116         printf("%d
",ans[i]);
117     return 0;
118 }
原文地址:https://www.cnblogs.com/blog-Dr-J/p/10116072.html