HDU 4638 树状数组 想法题

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4638

解题思路:

题意为询问一段区间里的数能组成多少段连续的数。先考虑从左往右一个数一个数添加,考虑当前添加了i - 1个数的答案是x,那么添加完i个数后的答案是多少?可以看出,是根据a[i]-1和a[i]+1是否已经添加而定的,如果a[i]-1或者a[i]+1已经添加一个,则段数不变,如果都没添加则段数加1,如果都添加了则段数减1。设v[i]为加入第i个数后的改变量,那么加到第x数时的段数就是sum{v[i]} (1<=i<=x}。仔细想想,若删除某个数,那么这个数两端的数的改变量也会跟着改变,这样一段区间的数构成的段数就还是他们的v值的和。将询问离线处理,按左端点排序后扫描一遍,左边删除,右边插入,查询就是求区间和。

以上摘自杭电的解题报告

标程:标程是从左边插入,从右边删除

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 #include<iostream>
 5 using namespace std;
 6 const int N = 100005;
 7 struct Q
 8 {
 9     int l,r,id;
10 }query[N];
11 bool cmp(Q x,Q y)
12 {
13     return x.r>y.r;
14 }
15 int a[N],c[N],d[N];
16 int lowbit(int x)
17 {
18     return x&(-x);
19 }
20 void up(int x,int v)
21 {
22     while(x<N)
23     {
24         c[x]+=v;
25         x+=lowbit(x);
26     }
27 }
28 int getsum(int x)
29 {
30     int r=0;
31     while(x>0)
32     {
33         r+=c[x];
34         x-=lowbit(x);
35     }
36     return r;
37 }
38 bool u[N];
39 int ret[N],ps[N];
40 int main()
41 {
42     int T,n,m,i,j;
43     scanf("%d",&T);
44     while(T--)
45     {
46         scanf("%d%d",&n,&m);
47         for(i=1;i<=n;i++)
48         {
49             scanf("%d",&a[i]);
50             ps[a[i]]=i;
51         }
52         memset(c,0,sizeof(c));
53         memset(d,0,sizeof(d));
54         memset(u,0,sizeof(u));
55         for(i=n;i>0;i--)
56         {
57             if(u[a[i]-1])d[i]++;
58             if(u[a[i]+1])d[i]++;
59             if(d[i]==0)
60             {
61                 up(i,1);
62             }
63             else if(d[i]==2)
64             {
65                 up(i,-1);
66             }
67             u[a[i]]=true;
68         }
69         for(i=0;i<m;i++)
70         {
71             scanf("%d%d",&query[i].l,&query[i].r);
72             query[i].id=i;
73         }
74         sort(query,query+m,cmp);
75         int j=n;
76         for(i=0;i<m;i++)
77         {
78             while(j>query[i].r)
79             {
80                 if(a[j]>1&&ps[a[j]-1]<j)
81                 {
82                     d[ps[a[j]-1]]--;
83                     up(ps[a[j]-1],1);
84                 }
85                 if(a[j]<n&&ps[a[j]+1]<j)
86                 {
87                     d[ps[a[j]+1]]--;
88                     up(ps[a[j]+1],1);
89                 }
90                 j--;
91             }
92             ret[query[i].id]=getsum(query[i].r)-getsum(query[i].l-1);
93         }
94         for(i=0;i<m;i++)printf("%d
",ret[i]);
95     }
96     return 0;
97 }
View Code

 我觉得看完解题思路后我还是不太明白,就模拟了一下样例。

3 1 2 5 4

得到的数组v 为 1 1 -1 1 0

这时查询任意1-k区间的段数=sum(vi)(1=<i<=k),都是对的···神奇···

在删除3后,3-1所在的位置为3,即v[3]=v[3]+1,3+1所在的位置为5。++v[5]。

v数组变为 1  1 0 1 0 ,查询2-i区间的段数=sum(vi)(2=<i<=k),同样是对的···更神奇···有兴趣的可以自己证明

我的代码:

 1 #include <cstdio>
 2 #include <algorithm>
 3 #define N 100005
 4 using namespace std;
 5 struct Node
 6 {
 7     int l,r,index;
 8 } p[N];
 9 int a[N],rank[N],ans[N],c[N];
10 int n,m;
11 
12 bool cmp(Node a,Node b)
13 {
14     return a.l < b.l;
15 }
16 int lowbit(int x)
17 {
18     return x&(-x);
19 }
20 void add(int x,int v)
21 {
22     while(x<N)
23     {
24         c[x] += v;
25         x += lowbit(x);
26     }
27 }
28 int sum(int x)
29 {
30     int s =0;
31     while(x)
32     {
33         s += c[x];
34         x -= lowbit(x);
35     }
36     return s;
37 }
38 
39 void init()
40 {
41     scanf("%d%d",&n,&m);
42     rank[0] =rank[n+1] = n+1;
43     for(int i=1; i<=n; ++i)
44     {
45         scanf("%d",&a[i]);
46         rank[a[i]] = i;
47     }
48     for(int i=0; i<m; ++i)
49     {
50         scanf("%d%d",&p[i].l,&p[i].r);
51         p[i].index = i;
52     }
53 }
54 int main()
55 {
56 //    freopen("in.cpp","r",stdin);
57     int t;
58     scanf("%d",&t);
59     while(t--)
60     {
61         init();
62         memset(c, 0, sizeof(c) );
63         for(int i=1; i<=n; ++i)
64         {
65             int d =0;
66             if(rank[a[i]+1] < i ) ++d;
67             if(rank[a[i]-1] < i ) ++d;
68             if(d == 0 )
69                 add(i,1);
70             else if(d == 2)
71                 add(i,-1);
72         }
73         sort(p,p+m,cmp);
74         rank[0] = rank[n+1] = 0;
75         int j=1;
76         for(int i=0; i<m; ++i)
77         {
78             int t = p[i].l;
79             while(j < t)
80             {
81                 if(rank[a[j]+1] > j)
82                     add(rank[a[j]+1],1);
83                 if(rank[a[j]-1] > j)
84                     add(rank[a[j]-1],1);
85                 ++j;
86             }
87             ans[p[i].index] = sum(p[i].r)-sum(p[i].l-1);
88         }
89         for(int i=0; i<m; ++i)
90             printf("%d
",ans[i]);
91     }
92     return 0;
93 }
View Code
原文地址:https://www.cnblogs.com/allh123/p/3238817.html