题解「CF1344D Resume Review」

链接

CodeForces 1344D

题意

一个长度为 (n) 的数组 (a_i),构造 自然数 数组满足:

  • (forall i,b_iin[0,a_i]).
  • (sum_{i=1}^n b_i=k)

在这个前提下,求令 (sum_{i=1}^n b_i(a_i-b_i^2)) 达到最大值的任意一组 (b_i).

题解

首先很容易想到暴力,设 (b_i) 的值为 (x) ,则 (b_i) 增加 (1) 对答案新增的贡献为 (Delta_i=(x+1)left(a_i-(x-1^2) ight)-x(a_i-x^2)=a_i-3x^2+3x-1)

那么,初始时设所有 (b_i) 都为 (0) ,用一个堆维护,每次从堆中取出 (Delta_i) 最大的 (b_i)(1),并记录下 (b_i) ,最后得到的 (b_i) 就是题目所求的 (b_i) 。时间复杂度为 (O(k log n)),显然不能承受。

观察上述过程,我们发现在暴力过程中,由于每次都采取当前的最优策略,所以对于任意的 (y),第 (y) 次增加的值是 固定的 ,并且是单调递减的。那么这样的一个增量便具有了单调性。

对于一个 (Delta_i),我们知道如果取这个 (i) 把它对应的 (b_i)(1),意味着其他位置上的 (Delta) 值并不会大于 (Delta_i),也就是说我们可以根据这个 (Delta_i) ,大概确定出所有的 (b_i)。我们称这个 (Delta_i) 为最大增量.

由于满足单调性,最大增量可以被二分出来,( ext{check}) 函数也非常简单,根据最大增量求出每个 (b_i) ,判断 (sum_{i=1}^n b_i) 是否大于 (k) 即可。至于求 (b_i) 的过程可以解方程,也可以直接二分,前者可以近似认为是 (O(n log maxVal)),而后者则是 (O(n log^2 maxVal))

值得一提的是,由于最大增量 非严格单调递减 ,以及二分后求出的 (b_i) 不一定满足 (sum b_i=k) 的约束,我们规定二分求出的 (b_i) 满足 (sum b_ileq k),并且二分求出的 (Delta) 不是一个值,是一个长度为 (2) 的区间,对边界进行控制,这样就不会遗漏任何情况。

Show the Code

#include<cstdio>
#include<climits>
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
typedef long long ll;
ll n,m,sum=0;
ll a[100005],b[100005];
inline ll read() {
    register ll x=0,f=1;register char s=getchar();
    while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
    while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
    return x*f;
}
//-3*y^2+3*y+x-1=-3(y-(1/2))^2+...
inline ll f(ll x,ll y) {return x==y? LLONG_MAX:x-3*y*(y-1)-1;}//对称轴x=1/2
inline ll calc(ll x,ll lim) {//f(a[x],b[x])<=lim
    ll l=1,r=a[x],res=a[x];
    while(l<=r) {
        int mid=l+r>>1;
        if(f(a[x],mid)<=lim) {r=mid-1;res=mid;}
        else {l=mid+1;} 
    }
    return res;
}
inline bool check(ll x) {
    sum=0;
    for(register int i=1;i<=n;++i) {
        b[i]=calc(i,x);
        sum+=b[i];
    }
    return sum<m;
}
int main() {
    n=read();m=read();
    ll l=0,r=0;
    for(register int i=1;i<=n;++i) {
        a[i]=read();
        l=min(l,f(a[i],a[i]-1));
        r=max(r,f(a[i],0));
    }
    while(l+1<r) {
        ll mid=l+r>>1;
        if(check(mid)) r=mid;
        else l=mid;
    }
    if(check(l)) r=l;
    check(r);
    m-=sum;
    for(register int i=1;i<=n;++i) if(m>0&&b[i]<a[i]&&f(a[i],b[i])==r) ++b[i],--m; 
    for(register int i=1;i<=n;++i) printf("%lld ",b[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/tommy0103/p/13050486.html