http://acm.hdu.edu.cn/showproblem.php?pid=4747
题意:
给出一段数据,求出所有区间的Mex和。
思路:
这道题目很不错,参考了大神博客http://www.cnblogs.com/Griselda/archive/2013/11/20/3433595.html。
先计算出mex【i】(表示1~i之间的mex),这肯定是非递减的。
接下来每次将最左边的数删掉,删掉这个数之后,有些区间可能会受到影响,有些不会,那么哪些会受到影响呢?
先找到该数的下一个出现位置next,那么在next之前的区间就有可能会受到影响,而next后面的就一定不会受到影响,因为next【i】可以代替删去的那个数了。
线段树的做法是这样的:
sum【】保存的是【当前左端点,>=左端点】区间的mex之和,那么sum【1】表示的就是 当前左端点到其余各个点的mex之和。(比如左端点为2,那么sum【1】=mex(2,2)+mex(2,3)+mex(2,4)+...+mex(2,n))。
mx【】保存的是当前这些区间内的最大mx值。
每次删掉左端点之后,我们就要去判断接下来的区间中mex值是否会有变化,如果mx【1】大于了当前左端点,那么一定存在区间,它的mex要改为当前左端点的值。
1 #include<iostream> 2 #include<algorithm> 3 #include<cstring> 4 #include<cstdio> 5 #include<sstream> 6 #include<vector> 7 #include<stack> 8 #include<queue> 9 #include<cmath> 10 #include<map> 11 #include<set> 12 using namespace std; 13 typedef long long ll; 14 typedef long long ull; 15 typedef pair<int,int> pll; 16 const int INF = 0x3f3f3f3f; 17 const int maxn = 200000 + 5; 18 19 int n; 20 int a[maxn]; 21 int Next[maxn]; 22 int mex[maxn]; 23 24 int mx[maxn<<2]; 25 ll sum[maxn<<2]; 26 int col[maxn<<2]; 27 28 map<int,int> mp; 29 30 void PushUp(int o) 31 { 32 sum[o]=sum[o<<1]+sum[o<<1|1]; 33 mx[o]=max(mx[o<<1],mx[o<<1|1]); 34 } 35 36 void PushDown(int o,int m) 37 { 38 if(col[o]!=-1) 39 { 40 col[o<<1]=col[o<<1|1]=col[o]; 41 sum[o<<1]=(m-m>>1)*col[o]; 42 sum[o<<1|1]=(m>>1)*col[o]; 43 mx[o<<1]=mx[o<<1|1]=col[o]; 44 col[o]=-1; 45 } 46 } 47 48 void build(int l, int r, int o) 49 { 50 col[o]=-1; 51 if(l==r) 52 { 53 sum[o]=mx[o]=mex[l]; 54 return; 55 } 56 int mid=(l+r)/2; 57 build(l,mid,o<<1); 58 build(mid+1,r,o<<1|1); 59 PushUp(o); 60 } 61 62 void update(int ql, int qr, int l, int r, int x, int o) 63 { 64 if(ql<=l && qr>=r) 65 { 66 col[o]=x; 67 sum[o]=x*(r-l+1); 68 mx[o]=x; 69 return; 70 } 71 PushDown(o,r-l+1); 72 int mid=(l+r)/2; 73 if(ql<=mid) update(ql,qr,l,mid,x,o<<1); 74 if(qr>mid) update(ql,qr,mid+1,r,x,o<<1|1); 75 PushUp(o); 76 } 77 78 int query(int l, int r, int x, int o) 79 { 80 if(l==r) return l; 81 PushDown(o,r-l+1); 82 int mid=(l+r)/2; 83 if(mx[o<<1]>x) return query(l,mid,x,o<<1); 84 else return query(mid+1,r,x,o<<1|1); 85 } 86 87 int main() 88 { 89 //freopen("in.txt","r",stdin); 90 while(~scanf("%d",&n) && n) 91 { 92 mp.clear(); 93 for(int i=1;i<=n;i++) scanf("%d",&a[i]); 94 95 //计算出mex(1~i)的最小非负整数 96 int tmp=0; 97 for(int i=1;i<=n;i++) 98 { 99 mp[a[i]]=1; 100 while(mp.find(tmp)!=mp.end()) tmp++; 101 mex[i]=tmp; 102 } 103 104 build(1,n,1); 105 106 //计算出a[i]的下一次出现的位置next值 107 mp.clear(); 108 for(int i=n;i>=1;i--) 109 { 110 if(mp.find(a[i])==mp.end()) Next[i]=n+1; 111 else Next[i]=mp[a[i]]; 112 mp[a[i]]=i; 113 } 114 115 ll ans=0; 116 for(int i=1;i<=n;i++) 117 { 118 ans+=sum[1]; 119 if(mx[1]>a[i]) 120 { 121 int l=query(1,n,a[i],1); 122 int r=Next[i]; 123 if(l<r) update(l,r-1,1,n,a[i],1); 124 } 125 update(i,i,1,n,0,1); 126 } 127 printf("%lld ",ans); 128 } 129 }