hdu 4747 Mex

http://acm.hdu.edu.cn/showproblem.php?pid=4747

设我们输入的数组为 a[],我们需要从 1 到 n 遍历, 假设遍历到 i 时, 遍历的过程中用b[j]表示从 i 到 j 没出现的最小自然数

先从 n 到 1 扫一遍求出从 1 到各个点的b[j]值

然后遍历a[] 实际上就是不断的把当前a[i] 去掉,比如说去掉a[3]时,剩下的b[4]---b[n] 就表示从4到其他后续点形成的区间中没出现的最小自然数

要知道从 i 到 n ,b[]的值始终是单调递增的

我们每去掉当前a[i]会对b[]数组产生影响,

设下一个和a[i]相等的数出现的位置是 r 那么去掉a[i] 对 r 以及 r 以后的b[] 没有影响

在 i 和 r 之间受影响的段b[]是大于等于a[i]的那一段 假设是(l,r), 这个段内的b[]都大于等于a[i]

去掉a[i]的影响就是这个段内的b[] 都要等于 a[i]

找到r可以事先标记,找 l 和更新段 (l,r) 有两种方法

1,二分找到 l ,然后遍历更新段 (l,r)    这样代码比较短,也比较易懂,但比较耗时,不过可以过

2,线段树维护                                这样代码量会比较大,不过耗时少,线段树的解法应该比较标准

两种代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<vector>
#include<list>
#include<stack>
#include<queue>
#include<map>

using namespace std;

typedef long long ll;
typedef pair<int,int> pp;

const int INF=0x3f3f3f3f;

const int N=200002;
bool exist[N];
int a[N],next[N],f[N];
int b[N];
int bsh(int l,int r,int k)
{
    while(l<=r)
    {
        int mid=(l+r)>>1;
        if(b[mid]<=k) l=mid+1;
        else r=mid-1;
    }
    return r;
}
int main()
{
    //freopen("data.in","r",stdin);
    int n;
    while(scanf("%d",&n)!=EOF)
    {
       if(n==0) break;
       for(int i=1;i<=n;++i)
       scanf("%d",&a[i]);
       for(int i=0;i<=n;++i)
       f[i]=n+1;
       for(int i=n;i>=1;--i)
       if(a[i]<n)
       {
           next[i]=f[a[i]];
           f[a[i]]=i;
       }
       ll ans=0;
       memset(exist,false,sizeof(exist));
       ll tmp=0;int l=0;
       for(int i=1;i<=n;++i)
       {
           if(a[i]<n)
           {
               exist[a[i]]=true;
               while(exist[l]) ++l;
           }
           b[i]=l;
           tmp+=b[i];
       }
       ans=tmp;
       for(int i=1;i<n;++i)
       {
           if(a[i]<n)
           {
               int r=next[i];
               int l=bsh(i,r-1,a[i]);
               for(int j=l+1;j<r;++j)
               {
                   tmp-=(b[j]-a[i]);
                   b[j]=a[i];
               }
           }
           tmp-=b[i];
           ans+=tmp;
       }
       cout<<ans<<endl;
    }
    return 0;
}

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<vector>
#include<list>
#include<stack>
#include<queue>
#include<map>

using namespace std;

typedef long long ll;
typedef pair<int,int> pp;

const int INF=0x3f3f3f3f;

const int N=200002;
bool exist[N];
int a[N],next[N],f[N];
int b[N];
struct node
{
    int l,r,k,least;
    ll sum;
}tr[N*4];
void build(int x,int l,int r)
{
    tr[x].l=l;tr[x].r=r;tr[x].k=-1;
    if(l==r)
    {
        tr[x].least=b[l];
        tr[x].sum=b[l];
        return ;
    }
    int mid=(l+r)>>1;
    build((x<<1),l,mid);
    build((x<<1)|1,mid+1,r);
    tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least);
    tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum);
}
void update(int x,int l,int r,int k)
{
    if(l>r) return ;
    if(tr[x].l==l&&tr[x].r==r)
    {
        tr[x].least=k;
        tr[x].k=k;
        tr[x].sum=(ll)k*(tr[x].r-tr[x].l+1);
        return ;
    }
    if(tr[x].k!=-1)
    {
        tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;
        tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);
        tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;
        tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);
        tr[x].k=-1;
    }
    int mid=(tr[x].l+tr[x].r)>>1;
    if(r<=mid)
    update(x<<1,l,r,k);
    else if(l>mid)
    update((x<<1)|1,l,r,k);
    else
    {
        update(x<<1,l,mid,k);
        update((x<<1)|1,mid+1,r,k);
    }
    tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least);
    tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum);
    tr[x].k=-1;
}
int get(int x,int l,int r,int w)
{
    if(tr[x].l==tr[x].r)
    {
        if(tr[x].least>w)
        return (l-1);
        return l;
    }
    if(tr[x].k!=-1)
    {
        tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;
        tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);
        tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;
        tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);
        tr[x].k=-1;
    }
    int mid=(tr[x].l+tr[x].r)>>1;
    if(r<=mid)
    return get(x<<1,l,r,w);
    else if(l>mid)
    return get((x<<1)|1,l,r,w);
    else
    {
        if(tr[(x<<1)|1].least<=w)
        return get((x<<1)|1,mid+1,r,w);
        else
        return get(x<<1,l,mid,w);
    }
}
ll gsum(int x,int l,int r)
{
    if(l>r) return 0;

    if(tr[x].l==l&&tr[x].r==r)
    return tr[x].sum;
    if(tr[x].k!=-1)
    {
        tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;
        tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);
        tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;
        tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);
        tr[x].k=-1;
    }
    int mid=(tr[x].l+tr[x].r)>>1;
    if(r<=mid)
    return gsum(x<<1,l,r);
    else if(l>mid)
    return gsum((x<<1)|1,l,r);
    else
    return gsum(x<<1,l,mid)+gsum((x<<1)|1,mid+1,r);
}
int main()
{
    int n;
    while(scanf("%d",&n)!=EOF)
    {
       if(n==0) break;
       for(int i=1;i<=n;++i)
       scanf("%d",&a[i]);
       for(int i=0;i<=n;++i)
       f[i]=n+1;
       for(int i=n;i>=1;--i)
       if(a[i]<n)
       {
           next[i]=f[a[i]];
           f[a[i]]=i;
       }
       ll ans=0;
       memset(exist,false,sizeof(exist));
       int l=0;
       for(int i=1;i<=n;++i)
       {
           if(a[i]<n)
           {
               exist[a[i]]=true;
               while(exist[l]) ++l;
           }
           b[i]=l;
       }
       build(1,1,n);
       ans+=gsum(1,1,n);
       for(int i=1;i<n;++i)
       {
           if(a[i]<n)
           {
               int r=next[i];
               int l=get(1,i,r-1,a[i]);
               update(1,l+1,r-1,a[i]);
           }
           ans+=gsum(1,i+1,n);
       }
       cout<<ans<<endl;
    }
    return 0;
}
原文地址:https://www.cnblogs.com/liulangye/p/3330003.html