HDU 4747 Mex(线段树)

http://acm.hdu.edu.cn/showproblem.php?pid=4747

题意:

给出一段数据,求出所有区间的Mex和。

思路:
这道题目很不错,参考了大神博客http://www.cnblogs.com/Griselda/archive/2013/11/20/3433595.html

先计算出mex【i】(表示1~i之间的mex),这肯定是非递减的。

接下来每次将最左边的数删掉,删掉这个数之后,有些区间可能会受到影响,有些不会,那么哪些会受到影响呢?

先找到该数的下一个出现位置next,那么在next之前的区间就有可能会受到影响,而next后面的就一定不会受到影响,因为next【i】可以代替删去的那个数了。

线段树的做法是这样的:

sum【】保存的是【当前左端点,>=左端点】区间的mex之和,那么sum【1】表示的就是 当前左端点到其余各个点的mex之和。(比如左端点为2,那么sum【1】=mex(2,2)+mex(2,3)+mex(2,4)+...+mex(2,n))。

mx【】保存的是当前这些区间内的最大mx值。

每次删掉左端点之后,我们就要去判断接下来的区间中mex值是否会有变化,如果mx【1】大于了当前左端点,那么一定存在区间,它的mex要改为当前左端点的值。

  1 #include<iostream>
  2 #include<algorithm>
  3 #include<cstring>
  4 #include<cstdio>
  5 #include<sstream>
  6 #include<vector>
  7 #include<stack>
  8 #include<queue>
  9 #include<cmath>
 10 #include<map>
 11 #include<set>
 12 using namespace std;
 13 typedef long long ll;
 14 typedef long long ull;
 15 typedef pair<int,int> pll;
 16 const int INF = 0x3f3f3f3f;
 17 const int maxn = 200000 + 5;
 18 
 19 int n;
 20 int a[maxn];
 21 int Next[maxn];
 22 int mex[maxn];
 23 
 24 int mx[maxn<<2];
 25 ll sum[maxn<<2];
 26 int col[maxn<<2];
 27 
 28 map<int,int> mp;
 29 
 30 void PushUp(int o)
 31 {
 32     sum[o]=sum[o<<1]+sum[o<<1|1];
 33     mx[o]=max(mx[o<<1],mx[o<<1|1]);
 34 }
 35 
 36 void PushDown(int o,int m)
 37 {
 38     if(col[o]!=-1)
 39     {
 40         col[o<<1]=col[o<<1|1]=col[o];
 41         sum[o<<1]=(m-m>>1)*col[o];
 42         sum[o<<1|1]=(m>>1)*col[o];
 43         mx[o<<1]=mx[o<<1|1]=col[o];
 44         col[o]=-1;
 45     }
 46 }
 47 
 48 void build(int l, int r, int o)
 49 {
 50     col[o]=-1;
 51     if(l==r)
 52     {
 53         sum[o]=mx[o]=mex[l];
 54         return;
 55     }
 56     int mid=(l+r)/2;
 57     build(l,mid,o<<1);
 58     build(mid+1,r,o<<1|1);
 59     PushUp(o);
 60 }
 61 
 62 void update(int ql, int qr, int l, int r, int x, int o)
 63 {
 64     if(ql<=l && qr>=r)
 65     {
 66         col[o]=x;
 67         sum[o]=x*(r-l+1);
 68         mx[o]=x;
 69         return;
 70     }
 71     PushDown(o,r-l+1);
 72     int mid=(l+r)/2;
 73     if(ql<=mid)  update(ql,qr,l,mid,x,o<<1);
 74     if(qr>mid)   update(ql,qr,mid+1,r,x,o<<1|1);
 75     PushUp(o);
 76 }
 77 
 78 int query(int l, int r, int x, int o)
 79 {
 80     if(l==r)  return l;
 81     PushDown(o,r-l+1);
 82     int mid=(l+r)/2;
 83     if(mx[o<<1]>x)  return query(l,mid,x,o<<1);
 84     else return query(mid+1,r,x,o<<1|1);
 85 }
 86 
 87 int main()
 88 {
 89     //freopen("in.txt","r",stdin);
 90     while(~scanf("%d",&n) && n)
 91     {
 92         mp.clear();
 93         for(int i=1;i<=n;i++)  scanf("%d",&a[i]);
 94 
 95         //计算出mex(1~i)的最小非负整数
 96         int tmp=0;
 97         for(int i=1;i<=n;i++)
 98         {
 99             mp[a[i]]=1;
100             while(mp.find(tmp)!=mp.end())  tmp++;
101             mex[i]=tmp;
102         }
103 
104         build(1,n,1);
105 
106         //计算出a[i]的下一次出现的位置next值
107         mp.clear();
108         for(int i=n;i>=1;i--)
109         {
110             if(mp.find(a[i])==mp.end())  Next[i]=n+1; 
111             else Next[i]=mp[a[i]];
112             mp[a[i]]=i;
113         }
114 
115         ll ans=0;
116         for(int i=1;i<=n;i++)
117         {
118             ans+=sum[1];
119             if(mx[1]>a[i])
120             {
121                 int l=query(1,n,a[i],1);
122                 int r=Next[i];
123                 if(l<r) update(l,r-1,1,n,a[i],1);
124             }
125             update(i,i,1,n,0,1);
126         }
127         printf("%lld
",ans);
128     }
129 }
原文地址:https://www.cnblogs.com/zyb993963526/p/7210089.html