【HDU 4747 Mex】线段数

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747

题意:有一组序列a[i](1<=i<=N), 让你求所有的mex(l,r), mex(l,r)表示区间[l,r]中最小的未在序列中出现的非负整数。

思路:冥思苦想半天无想法,白做了那么多线段树。 很明显的维护区间问题,容易想到线段树,比较难想到操作。 枚举一个序列的所mex(1,i),mex(2,i)……可以发现序列mex(x,i)是一个单调递增序列,我们需要求得就是所有以x开头的序列和,mex(x,i)(x<=i<=n)。这点确定了就好办了,记录每个位置的数后面最早重复出现的位置next[x],如果无则为设n+1。那么我们就可以发现,当第x个数所对应的序列 mex(x,i)(x<=i<=n)所对应的序列求完之后,删去此位置的数,位置x+1~next[x]-1序列中mex值大于a[x]的都改为a[x],因为a[x]没有了,下一个a[x]还未出现,所以可以证明这样做是正确的。从1到n扫一遍亦求出了所有的mex()。

基本上所有的操作都可以用到线段树。开始没有想到一点的是如何找序列中刚好大于a[x]的位置,并且此位置到next[x]-1赋值为a[x],怎么都没想到log(n)的操作,其实这里依然可以用到线段树,因为序列是单调递增的,另开一个区间维护序列mavv[u]表示区间中最大的mex值,随着询问以及其他操作成段更新即可。

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cmath>
  4 #include <map>
  5 #include <algorithm>
  6 #include <cstring>
  7 #include <sstream>
  8 using namespace std;
  9 
 10 #define lz 2*u,l,mid
 11 #define rz 2*u+1,mid+1,r
 12 typedef long long lld;
 13 const int maxn=222222;
 14 int a[maxn], b[maxn], next[maxn];
 15 lld sum[4*maxn], mavv[4*maxn], flag[4*maxn];
 16 map<int,int>mp;
 17 
 18 void push_up(int u, int l, int r)
 19 {
 20     sum[u]=sum[2*u]+sum[2*u+1];
 21     mavv[u]=mavv[2*u+1];
 22 }
 23 
 24 void push_down(int u, int l, int r)
 25 {
 26     int mid=(l+r)>>1;
 27     if(flag[u]!=-1)
 28     {
 29         flag[2*u]=flag[2*u+1]=flag[u];
 30         mavv[2*u]=mavv[2*u+1]=flag[u];
 31         sum[2*u]=(lld)(mid-l+1)*flag[u];
 32         sum[2*u+1]=(lld)(r-mid)*flag[u];
 33         flag[u]=-1;
 34     }
 35 }
 36 
 37 void build(int u, int l, int r)
 38 {
 39     flag[u]=-1;
 40     int mid=(l+r)>>1;
 41     if(l==r)
 42     {
 43         sum[u]=mavv[u]=b[l];
 44         return ;
 45     }
 46     build(lz);
 47     build(rz);
 48     push_up(u,l,r);
 49 }
 50 
 51 void Update(int u, int l, int r, int tl, int tr, int val)
 52 {
 53     if(tl>tr) return ;
 54     if(tl<=l&&r<=tr)
 55     {
 56         mavv[u]=val;
 57         sum[u]=(lld)val*(r-l+1);
 58         flag[u]=val;
 59         return ;
 60     }
 61     push_down(u,l,r);
 62     int mid=(l+r)>>1;
 63     if(tr<=mid) Update(lz,tl,tr,val);
 64     else if(tl>mid) Update(rz,tl,tr,val);
 65     else
 66     {
 67         Update(lz,tl,mid,val);
 68         Update(rz,mid+1,tr,val);
 69     }
 70     push_up(u,l,r);
 71 }
 72 
 73 int find(int u, int l, int r, int tmp)
 74 {
 75     if(l==r) return l;
 76     push_down(u,l,r);
 77     int mid=(l+r)>>1;
 78     if(mavv[2*u]>tmp) return find(lz,tmp);
 79     else return find(rz,tmp);
 80 }
 81 
 82 int main()
 83 {
 84     int n;
 85     while(cin >> n,n)
 86     {
 87         for(int i=1; i<=n; i++) scanf("%d",a+i);
 88         mp.clear();
 89         for(int i=n; i>=1; i--)
 90         {
 91             if(mp[ a[i] ]) next[i]=mp[ a[i] ];
 92             else next[i]=n+1;
 93             mp[ a[i] ]=i;
 94         }
 95         mp.clear();
 96         int x=0;
 97         for(int i=1; i<=n; i++)
 98         {
 99             mp[ a[i] ]=1;
100             while(mp[x]) ++x;
101             b[i]=x;
102         }
103         build(1,1,n);
104         lld ans=0;
105         for(int i=1; i<=n; i++)
106         {
107             ans+=sum[1];
108             if(mavv[1]>a[i])
109             {
110                 int id=find(1,1,n,a[i]);
111                 Update(1,1,n,max(id,i+1),next[i]-1,a[i]);
112             }
113             Update(1,1,n,i,i,0);
114         }
115         cout << ans <<endl;
116     }
117 }
View Code
原文地址:https://www.cnblogs.com/kane0526/p/3329169.html