BZOJ 1112: [POI2008]砖块Klo Splay + 性质分析

Code: 

#include<bits/stdc++.h>
using namespace std; 
#define setIO(s) freopen(s".in","r",stdin)
// 1 if node x is the right child of its parent, 0 if it is the left child.
#define get(x) (ch[f[(x)]][1]==(x))
typedef long long ll;

const int maxn=200000;                 // capacity of the node arrays; nodes are column indices, so this must exceed n
const long long inf=10000000000000LL;  // initial value of ans (1e13), i.e. "no window evaluated yet"

int root;   // current splay-tree root, 0 when the tree is empty
int cc;     // leftmost column of the best window found so far (written by solve)
int kk;     // median height of that best window (written by solve)

int h[maxn];          // column heights, read 1-based in main
int ch[maxn][2];      // ch[x][0] / ch[x][1]: left / right child, 0 = none
int f[maxn];          // parent node, 0 = none
int siz[maxn];        // number of nodes in the subtree rooted here
int val[maxn];        // NOTE(review): never referenced in this file
long long sumv[maxn]; // sum of heights in the subtree rooted here
long long ans=inf;    // smallest levelling cost found so far
void pushup(int x) 
{ 
    sumv[x]=sumv[ch[x][0]]+sumv[ch[x][1]]+h[x];   
    siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1; 
}
void rotate(int x) 
{
    int old=f[x],fold=f[old],which=get(x); 
    ch[old][which]=ch[x][which^1],f[ch[old][which]]=old; 
    ch[x][which^1]=old,f[old]=x,f[x]=fold; 
    if(fold) ch[fold][ch[fold][1]==old]=x; 
    pushup(x),pushup(old); 
} 
/*
 * Splay node x upward until it takes the position currently occupied by
 * `tar`, i.e. until x's parent equals tar's current parent. Then store x
 * into the slot `tar` refers to.
 * Pass `root` to make x the tree root; del passes a child slot to splay
 * inside a subtree. When x and its parent are same-side children the parent
 * is rotated first (zig-zig); otherwise x is rotated twice (zig-zag).
 */
void splay(int x,int &tar)
{
    const int stop=f[tar];
    int parent;
    while((parent=f[x])!=stop)
    {
        const int grand=f[parent];
        if(grand!=stop)
        {
            const bool sameSide=(get(parent)==get(x));
            if(sameSide) rotate(parent);
            else rotate(x);
        }
        rotate(x);
    }
    tar=x;
}
/*
 * Insert node `key` into the subtree held in slot x, whose parent is ff.
 * Descends right only when h[key] is strictly greater than the current
 * node's height, so equal heights go left. Aggregates are refreshed
 * bottom-up as the recursion unwinds.
 * Assumes node `key` has no children yet (pushup reads ch[key]).
 */
void insert(int &x,int key,int ff)
{
    if(x)
    {
        const int dir=(h[key]>h[x]) ? 1 : 0;
        insert(ch[x][dir],key,x);
    }
    else
    {
        x=key;
        f[key]=ff;
    }
    pushup(x);
}
/*
 * Return the node holding the kth smallest height (1-based) in the whole tree.
 * There is no bounds check: the caller must ensure 1 <= kth <= tree size.
 */
int query(int kth)
{
    int cur=root;
    for(;;)
    {
        const int leftSize=siz[ch[cur][0]];
        if(kth<=leftSize)
        {
            cur=ch[cur][0];
            continue;
        }
        if(kth==leftSize+1) return cur;
        kth-=leftSize+1;
        cur=ch[cur][1];
    }
}
void del(int x) 
{
    if(!ch[x][0]) root=ch[x][1], f[root]=ch[x][1]=0;    
    else if(!ch[x][1]) root=ch[x][0], f[root]=ch[x][0]=0;  
    else 
    {
        int l=ch[x][0]; 
        while(ch[l][1]) l=ch[l][1];   
        splay(l, ch[x][0]);   
        ch[l][1]=ch[x][1], f[ch[x][1]]=l, f[l]=ch[x][0]=ch[x][1]=0, pushup(l);
        root=l; 
    }
}
void solve(int k,int L)
{ 
    int mid=(k%2==0)?k/2:(k/2)+1,x=query(mid);     
    splay(x,root);  
    int l=ch[root][0],r=ch[root][1]; 
    ll re=0;  
    re+=(h[x]*siz[l]-sumv[l]);   
    re+=(sumv[r]-h[x]*siz[r]);   
    if(re<ans) ans=re, cc=L, kk=h[x];   
}
int main()
{
    // setIO("input");  
    int n,k,i,j; 
    scanf("%d%d",&n,&k); 
    for(i=1;i<=n;++i) scanf("%d",&h[i]);                           
    for(i=1;i<=k;++i) { insert(root,i,0); if(i%6==0) splay(i,root); }   
    solve(k,1);   
    for(i=k+1;i<=n;++i) 
    {
        j=i-k+1;     
        splay(j-1, root), del(j-1), insert(root,i,0), splay(i,root),  solve(k,j);             
    }
    printf("%lld
",ans);        
    return 0; 
}

  

原文地址:https://www.cnblogs.com/guangheli/p/11187007.html