zoj 3278 8G Island

二分

比较好的二分题目,需要花点脑筋想到,另外写的细节也多

题意:比较好懂,a数组有n个元素,b数组有m个元素,a数组的元素分别和b数组的元素相乘得到新的元素,那么一共会得到n*m个元素,将这些元素降序排序,找到第k大的元素是谁

为检验算法正确性,一个暴力的程序很容易写出来,关键是正解是什么,ab数组的元素个数都很多,暴力会超时,而且空间也不允许

使用二分,而且是嵌套的二分,两个二分的目的不同

外层二分是位了枚举答案,内层二分是检验当前的答案是否符合

做法:先将ab数组降序排序,那么我们知道a[1]*b[1]是最大值,a[n]*b[m]是最小值,其他元素的乘积一定在这个区间内,我们并不知道答案是多少,所以我们就枚举答案,用二分来枚举这个区间内的答案

ab元素的最大值是10^5,相乘是10^10,所以区间长度最长为10^10,二分的话就是30多次

接下来,枚举得到一个答案key,就去检验这个答案是否符合,我的做法和网上大部分的做法不同,我内层写了两个二分,他们都只写了1个,其实本质是一样的,下面说说我的做法

首先枚举a数组的每个元素,然后在b数组中找一个元素,这个元素的要求是,a[i]*b[j] > key,并且这个j是最后一个j,即a[i]*b[j+1] <= key,也就是找到最后一个b[j],和当前的a[i]相乘后的值大于当前的答案key(大于,而不是大于等于)

同理,要在b数组中找到一个元素b[k],a[i]*b[k] < key ,并且a[i] * b[k-1] >= key , 即找到一个最前的一个b[k]使得它和a[i]相乘的积小于当前的答案key(小于,不是小于等于)

那么我们就可以知道b数组中有多少个元素和a[i]相乘的积 = key

我们枚举了全部的a[i],就是为了知道所有n*m个积中,有多少>key , 有多少<key , n*m这个总数减去之后就可以知道知道多少有多少个元素和key相等了

也就是我们可以确定key在哪个区间

好像现在有10个积按降序排好了,已知key = 6 , 知道大于key的元素有2个,小于key的元素有4个,那么我们知道一定有4个元素和key=6相等,也就是说key一个出现在[3,6]这个位置内(总位置为[1,10])。比如我们要求第5大的数,那么这个数一定在第5位,而恰好就在[3,6]内,所以我们知道第5大的数字就是key = 6

如果我们要求第2大的数,发现不在[3,6]这些位置内,那么key=6一定不是第2大的数,而且可以知道第2大的数一定比key=6大,所以我们就可以缩短一半的范围去找另一个key(这就是外层循环做的事情了)

如果我们要找第8大的数,发现不在[3,6]这些位置内,那么key=6一定不是第8大的数,而且可以key=6大于第8大的数,所以我们可以缩短一半的范围去找另一个key(这就是外层循环做的事情了)

所以内存二分的时间复杂度是  n * log m ,所以总的时间复杂度是  log MAX * n * log m

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef long long ll;
const ll INF = 100000000010;
const int N = 100010;

ll n,m,kth,Left,Right;
ll a[N],b[N];
ll Min,Max;

bool cmp(int x ,int y)
{
    return x > y;
}

int main()
{
    while(cin >> n >> m >> kth)
    {
        for(int i=1; i<=n; i++) cin >> a[i];//scanf("%I64d",&a[i]);
        for(int i=1; i<=m; i++) cin >> b[i];//scanf("%I64d",&b[i]);
        sort(a+1,a+1+n,cmp);
        sort(b+1,b+1+m,cmp);
        b[0] = INF; b[m+1] = -1; //这两句要有
        Max= a[1] * b[1];
        Min = a[n] * b[m];
        ll Low = Min; 
        ll High = Max;
        ll ans = Max;
        
        while(Low <= High)
        {
            ll Mid = (Low + High) >> 1;
            ll key = Mid;

            ll s1 = 0 , s2 = 0;
            for(int i=1; i<=n; i++)
            {
                int low,high,res;
                ll val,__val;
                val = a[i] * b[1];
                __val = a[i] * b[m];
                if(__val > key) //这两个特判能砍掉一半的时间
                {
                    s1 += m;
                    continue;
                }
                if(val < key) //这两个特判能砍掉一半的时间
                {
                    s2 += m;
                    continue;
                }

                /******************找出最后一个大于key的元素下标,保存在res中,累加到s1中*************/
                low = 1; high = m; res = 0; //注意res的初始化
                while(low <= high)
                {
                    int mid = (low + high) >> 1;
                    val = a[i] * b[mid];
                    __val = a[i] * b[mid+1];
                    if(val > key && __val <= key)
                    { res = mid; break; }
                    else if(__val > key) 
                        low = mid + 1;
                    else
                        high = mid - 1;
                }
                s1 += res;
                /******************找出最后一个大于key的元素下标,保存在res中,累加到s1中**********/

                
                /******************找出第一个小于key的元素下标,保存在res中,累加到s2中***********/
                low = 1; high = m; res = m+1; //注意res的初始化
                while(low <= high)
                {
                    int mid = (low + high) >> 1;
                    val = a[i] * b[mid];
                    __val = a[i] * b[mid-1];
                    if(val < key && __val >= key)
                    { res = mid; break;}
                    else if(__val < key)
                        high = mid - 1;
                    else
                        low = mid + 1;
                }
                s2 += (m - res + 1);
                /******************找出第一个小于key的元素下标,保存在res中,累加到s2中*********/
            }

            Left = s1 + 1;
            Right = n*m - s2;

            if(Left <= kth && kth <= Right)
            { ans = key; break;}
            else if(Left > kth)
                Low = Mid + 1;
            else
                High = Mid - 1;
        }
        cout << ans << endl;
    }
    return 0;
}
原文地址:https://www.cnblogs.com/scau20110726/p/3127366.html