[loj6484]LJJ爱数书

先考虑没有区间,即对于长为$n$的序列${a_{1},a_{2},...,a_{n}}$(以下记$a_{0}=a_{n+1}=0$),求$F(a,k)$

问题即构造序列$b_{i}$,满足$forall 0le ile n,b_{i}equiv a_{i}-a_{i+1}(mod k)$且$sum_{i=0}^{n}b_{i}=0$,并最小化$frac{sum_{i=0}^{n}|b_{i}|}{2}$

最优的$b_{i}$满足$|b_{i}|<k$,否则不妨假设$b_{i}ge k$($b_{i}le -k$类似),将其减小$k$,并将一个$b_{j}<0$增加$k$(总存在,否则不满足$sum_{i=0}^{n}b_{i}=0$),显然$frac{sum_{i=0}^{n}|b_{i}|}{2}$严格减小

又因为$b_{i}equiv a_{i}-a_{i+1}(mod k)$,显然其最后的取值仅有两种,不妨都先取较小的一种,再选择$-frac{sum_{i=0}^{n}b_{i}}{k}$个位置增加$k$(取另一种),显然贪心选择收益最高(即$b_{i}$最小)的位置即可

(特别的,当$a_{i}-a_{i+1}equiv 0(mod k)$,令较小的取值为0,另一种取值为$k$,显然不会取到)

下面,问题变为一个区间,我们来维护上面的过程——

更准确的来说,由于$a_{i}in [0,k)$,这个较小的取值即$egin{cases}a_{i}-a_{i+1}&(a_{i}le a_{i+1})\a_{i}-a_{i+1}-k&(a_{i}>a_{i+1})end{cases}$

求出区间中$S=-sum_{i=l-1}^{r}b_{i}$(其中$b_{l-1}$和$b_{r}$要特判),首先答案以$S$为基础上修改,其次要选择$frac{S}{k}$个位置

关于如何找到这个位置,来二分这个$b_{i}$最小,并对两类分别统计(同样特判$b_{l-1}$和$b_{r}$),用可持久化线段树维护,以及再求出区间和即可

(直接在可持久化线段树上二分似乎并不太行,因为有两段)

总复杂度为$o(qlog^{2}n)$,可以通过

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 200005
 4 #define ll long long
 5 #define mid (l+r>>1)
 6 #define pil pair<int,ll>
 7 #define fi first
 8 #define se second
 9 pil f[N*60];
10 int V,n,m,q,x,y,z,a[N],b[N],rt1[N],rt2[N],ls[N*60],rs[N*60];
11 int New(int k){
12     f[++V]=f[k];
13     ls[V]=ls[k];
14     rs[V]=rs[k];
15     return V;
16 }
17 pil add(pil x,pil y){
18     return make_pair(x.fi+y.fi,x.se+y.se);
19 }
20 pil dec(pil x,pil y){
21     return make_pair(x.fi-y.fi,x.se-y.se);
22 }
23 void update(int &k,int l,int r,int x){
24     k=New(k);
25     f[k].fi++,f[k].se+=x;
26     if (l==r)return;
27     if (x<=mid)update(ls[k],l,mid,x);
28     else update(rs[k],mid+1,r,x);
29 }
30 pil query(int k,int l,int r,int x,int y){
31     if ((!k)||(l>y)||(x>r))return make_pair(0,0);
32     if ((x<=l)&&(r<=y))return f[k];
33     return add(query(ls[k],l,mid,x,y),query(rs[k],mid+1,r,x,y));
34 }
35 pil calc(int x,int y,int l,int k){
36     pil o1=dec(query(rt1[y-1],0,m,1,k-l),query(rt1[x-1],0,m,1,k-l));
37     pil o2=dec(query(rt2[y-1],0,m,l,k-1),query(rt2[x-1],0,m,l,k-1));
38     pil o=make_pair(o1.fi+o2.fi,1LL*k*o1.fi-o1.se+o2.se);
39     if ((a[y])&&(a[y]<=k-l)){
40         o.fi++;
41         o.se+=k-a[y];
42     }
43     if (l<=a[x]){
44         o.fi++;
45         o.se+=a[x];
46     }
47     return o;
48 }
49 int main(){
50     m=(1<<30)-1;
51     scanf("%d%d",&n,&q);
52     for(int i=1;i<=n;i++)scanf("%d",&a[i]);
53     for(int i=0;i<=n;i++)b[i]=a[i]-a[i+1];
54     for(int i=1;i<n;i++){
55         rt1[i]=rt1[i-1];
56         if (b[i]>0)update(rt1[i],0,m,b[i]);
57     }
58     for(int i=1;i<n;i++){
59         rt2[i]=rt2[i-1];
60         if (b[i]<=0)update(rt2[i],0,m,-b[i]);
61     }
62     for(int i=1;i<=q;i++){
63         scanf("%d%d%d",&x,&y,&z);
64         pil o1=dec(f[rt1[y-1]],f[rt1[x-1]]);
65         pil o2=dec(f[rt2[y-1]],f[rt2[x-1]]);
66         if (!a[y])o2.fi++;
67         else{
68             o1.fi++;
69             o1.se+=a[y];
70         }
71         o2.fi++,o2.se+=a[x];
72         ll ans=1LL*z*o1.fi-o1.se+o2.se;
73         int l=1,r=z-1;
74         while (l<r){
75             int midd=(l+r+1>>1);
76             if (calc(x,y,midd,z).fi>=ans/z)l=midd;
77             else r=midd-1;
78         }
79         pil o=calc(x,y,l,z);
80         o.se-=(o.fi-ans/z)*l,o.fi=ans/z;
81         ans+=1LL*z*o.fi-2*o.se;
82         printf("%lld
",ans/2);
83     }
84 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/14803186.html