线段树学习笔记

(未完待续)

推荐参考: notonlysuccess神犇的线段树总结

http://blog.csdn.net/kzzhr/article/details/10813301

(1)单点更新

HDU 1166   区间和

HDU 1754   区间最值

HDU 1394   区间和

HDU 2795   区间最值

常用模板:

Struct treetype
{
    int l,int r,int dat;
}t[3000];                //如果要对长度为n的区间建线段树,那么t数组至少要打到3*n

建树:

void build(int l,int r,int o)
{
    if (o>num) num=o;            //num:计数用,记录线段树上已有多少节点
    t[o].l=l;   t[o].r=r;
    if (l==r)
        t[o].dat=a[l];        //更新只包含一个点的区间(即树上的叶子节点)
    else
    {
        int mid=(l+r)/2;
        build(l,mid,2*o);
        build(mid+1,r,2*o+1);
        t[o].dat=max/min/sum(t[2*o].dat,t[2*o+1].dat);        //根据要求而定
    }
}

更新节点:

  1. 对于求区间和的问题:

      更新:令a[x]=m;

void update(int l,int r,int o)        //a[x]=m;
{
        t[o].l=l;   t[o].r=r;
        if ((l==x)&&(r==x))                //找到点节点:更新点
            t[o].dat=m;
        else
        {
            int mid=(l+r)/2;
            if (x<=mid)            //否则:对左or右区间进行处理
            {
                int tmp=t[2*o].dat;
                update(l,mid,2*o);
                t[o].dat+=t[2*o].dat-tmp;
            }
            else
            {
                int tmp=t[2*o+1].dat;
                update(mid+1,r,2*o+1);
                t[o].dat+=t[2*o+1].dat-tmp;
            }
        }
}

    2. 对于求区间最值的问题:

      更新:令a[tx]=ty;

void update(int l,int r,int o)
{
    if (o>num) return;            //防止出现o不停增加的死循环
                                      (其实加不加也无所谓-_-||if ((l==tx)&&(r==tx))            //找到点节点:更新点
        t[o].dat=ty;
    else
    {
        int mid=(l+r)/2;
        if (tx<=mid)            //否则:对左or右区间进行处理
        {
            update(l,mid,2*o);
            t[o].dat=max(t[o].dat,t[2*o].dat);
        }
        else
        {
            update(mid+1,r,2*o+1);
            t[o].dat=max(t[o].dat,t[2*o+1].dat);
        }
    }
}

求区间上的某种性质:

  1. 求区间和
int query_sum(int l,int r,int o)
{
    if (l>r)                //不可行的情况
        return 0;
    else if (l==r)
        return a[l];
    else
    {
        int tl=t[o].l,tr=t[o].r;
        if ((l==tl)&&(r==tr))        //正好命中线段树上的一整块区间,直接返回值即可
            return t[o].dat;
        else
        {
            int mid=(tl+tr)/2;        //否则拆开:l...mid和mid+1...r
            if (r<=mid)                   //注意红字部分!
                return (query_sum(l,r,2*o));
            else if (l>mid)
                return (query_sum(l,r,2*o+1));
            else
                return (query_sum(l,mid,2*o)+query_sum(mid+1,r,2*o+1));
        }
    }
}

    2.求区间最值(以最大值为例):

int query_max(int l,int r,int o)
{
    if (o>num) return 0;            //不可行的情况
    int tl=t[o].l,tr=t[o].r;
    if ((l==tl)&&(r==tr))        //正好命中线段树上的一整块区间,直接返回值即可
        return t[o].dat;
    else
    {
        int mid=(tl+tr)/2;            //否则拆开:l...mid和mid+1...r
        if (r<=mid)                    
            return (query_max(l,r,2*o));
        else if (l>mid)
            return (query_max(l,r,2*o+1));
        else if ((l<=mid)&&(mid<=r))
        {
            int tm1=query_max(l,mid,2*o);
            int tm2=query_max(mid+1,r,2*o+1);
            return max(tm1,tm2);
        }
    }
}    

    主程序:

cin>>n;
for i:=1 to n do cin>>a[i];
build(1,n,1);                                //建树

for i:=1 to m do
{
        cin>>tx>>ty;
        a[tx]=ty;
        update(1,n,1);                        //单点更新
}

cin>>ml>>mr;
ans1=query_sum(ml,mr,1);                //求区间和
ans2=query_max(ml,mr,1);                //求区间最值

备注:模板中的除法操作还可以用位运算优化:int mid=(l+r)>>1;

附录:网上的一段模板(hdu 1166)

 1 #include<stdio.h>
 2 #define MAX 50000+10
 3 int head[MAX];
 4 struct node
 5 {
 6     int l,r;
 7     int value;
 8 }tree[3*MAX];
 9 
10 void build(int l,int r,int v)            //建树
11 {
12     tree[v].l=l;
13     tree[v].r=r;
14     if(l==r)
15     {
16         tree[v].value=head[l];
17         return ;
18     }
19     int mid=(l+r)>>1;
20     build(l,mid,v*2);
21     build(mid+1,r,v*2+1);
22     tree[v].value=tree[v+v].value+tree[v+v+1].value;
23 }
24 
25 int queue(int a,int b,int v)            //求和
26 {
27     if(tree[v].l==a&&tree[v].r==b)
28     {
29         return tree[v].value;
30     }
31     int mid=(tree[v].l+tree[v].r)>>1;
32     if(b<=mid) return queue(a,b,v+v);
33     else if(a>mid) return queue(a,b,v+v+1);
34     else return queue(a,mid,v*2)+queue(mid+1,b,v*2+1);
35 }
36 
37 void update(int a,int b,int v)        //更新: head[a]+=b;
38 {
39     tree[v].value+=b;
40     if(tree[v].l==tree[v].r)
41     {
42         return ;
43     }
44     int mid=(tree[v].l+tree[v].r)>>1;
45     if(a<=mid) update(a,b,v*2);
46     else update(a,b,v*2+1);
47 }
48 
49 int main()
50 {
51     int T,N;
52     char str[10];
53     int a,b;
54     while(scanf("%d",&T)>0)
55     {
56         for(int h=0;h<T;h++)
57         {
58            scanf("%d",&N);
59            for(int i=1;i<=N;i++)
60            {
61                scanf("%d",&head[i]);
62            }
63            build(1,N,1);
64            printf("Case %d:
",h+1);
65            while(scanf("%s",str)>0)
66            {
67                if(str[0]=='A')
68                {
69                    scanf("%d %d",&a,&b);
70                    update(a,b,1);
71                }
72                else if(str[0]=='S')
73                {
74                     scanf("%d %d",&a,&b);
75                     update(a,-b,1);
76                }
77                else if(str[0]=='Q')
78                {
79                    scanf("%d %d",&a,&b);
80                     int ans=queue(a,b,1);
81                     printf("%d
",ans);
82                }
83                else
84                {
85                    break;
86                }
87            }
88 
89         }
90     }
91 }
View Code

---------------------------------------------------

2.成段更新

POJ3468            更新+求和

HDU1698            模板题-_-||

AHOI 2009          变形:加法+乘法混合

POJ 2528        +离散化

struct
{
    int l,r;
    __int64 dat;
}t[300010];

char ch;
__int64 ml,mr,md;
__int64 add[300010];
    
    建树:
void build(long l,long r,long long o)
{
    if (o>num) num=o;
    t[o].l=l;   t[o].r=r;
    if (l==r)
        t[o].dat=a[l];
    else
    {
        long mid=(l+r)/2;
        build(l,mid,2*o);
        build(mid+1,r,2*o+1);
        t[o].dat=t[2*o].dat+t[2*o+1].dat;
    }
}

更新:cin>>ml>>mr>>md;    update(ml,mr,1);        ->令a[ml..mr]+=md
void update(long l,long r,long long o)
{
    if (o>num) return;
    long tl=t[o].l,tr=t[o].r;
    if ((tl==l)&&(tr==r))
    {
        t[o].dat+=(tr-tl+1)*md;
        add[o]+=md;
        return;
    }
    else
    {
        long mid=(tl+tr)/2;
        //if (tl==tr) return;
        push_down(o,tl,mid,tr);
        if (r<=mid)
            update(l,r,2*o);
        else if (l>mid)
            update(l,r,2*o+1);
        else
        {
            update(l,mid,2*o);
            update(mid+1,r,2*o+1);
        }
        t[o].dat=t[2*o].dat+t[2*o+1].dat;
    }
}

push_down:(向下更新一层add数组)
void push_down(long long o,long l,long mid,long r)
{
    if (add[o]!=0)
    {
        long rl=2*o,rr=2*o+1;
        __int64 tm=add[o];
        add[rl]+=tm;
        add[rr]+=tm;
        t[rl].dat+=tm*(mid-l+1);
        t[rr].dat+=tm*(r-mid);
        add[o]=0;
    }
}

    求和:cin>>ml>>mr>>md;    query_sum(ml,mr,1);            ->求和:a[ml...mr]
__int64 query_sum(long l,long r,long long o)
{
    if (o>num) return 0;
    long tl=t[o].l,tr=t[o].r;
    if ((tl==l)&&(tr==r))
        return t[o].dat;
    else
    {
        long mid=(tl+tr)/2;
        push_down(o,tl,mid,tr);
        if (r<=mid)
            return query_sum(l,r,2*o);
        else if (l>mid)
            return query_sum(l,r,2*o+1);
        else
            return (query_sum(l,mid,2*o)+query_sum(mid+1,r,2*o+1));
    }
}

加离散化:(POJ 2528为例)

我们要重新给压缩后的线段标记起点和终点。
按照通用的离散化方法。。。。
首先依次读入线段端点坐标,存于post[MAXN][2]中,post[i][0]存第一条线段的起点,post[i][1]存第一条线段的终点,然后用一个结构题数组line[MAXN]记录信息,line[i].li记录端点坐标,line[i].num记录这个点属于哪条线段(可以用正负数表示,负数表示起点,正数表示终点)。假如有N条线段,就有2*N个端点。然后将line数组排序,按照端点的坐标,从小到大排。接着要把线段赋予新的端点坐标了。从左到右按照递增的次序,依次更新端点,假如2*N个点中,共有M个不同坐标的点,那么线段树的范围就是[1,M]。

        memset(v,false,sizeof(v));
        memset(t,0,sizeof(t));
        num=0;

        cin>>n;
        for (int i=1;i<=n;i++)
        {
            cin>>ml>>mr;
            sm[2*i-1].nm=ml;
            sm[2*i].nm=mr;
            sm[2*i-1].ky=i;     //start: >0
            sm[2*i].ky=-i;      //end:   <0
        }
        isort(1,2*n);    //sort: based on sm[i].nm

        tmp=0;
        for (int i=1;i<=2*n;i++)
        {
            if (sm[i].nm!=sm[i-1].nm) tmp++;
            if ((sm[i].nm-sm[i-1].nm==2)&&(i>1)) tmp++;
            int tkey=sm[i].ky;
            if (tkey>0)
                sgm[tkey].st=tmp;
            else if (tkey<0)
                sgm[-tkey].ed=tmp;
        }

        build(1,tmp,1);

        for (int i=1;i<=n;i++)
        {
            ml=sgm[i].st;   mr=sgm[i].ed;   md=i;
            update(ml,mr,1);        //a[ml..mr]=md;
        }

        ans=0;
        sum(1);
        cout<<ans<<endl;

AHOI2009  Seq

  1 AHOI 2009  seq:
  2 #include <iostream>
  3 #include <cstdio>
  4 #include <cstring>
  5 using namespace std;
  6 
  7 struct
  8 {
  9     int l,r;
 10     long long dat;
 11 }t[300010];
 12 
 13 long long a[100010],add[300010],mul[300010];
 14 int n,m,ml,mr,opr;
 15 long long md,p,tmp,num,ans;
 16 
 17 void modp(long long *nm)
 18 {
 19     *nm=((*nm)%p);
 20 }
 21 
 22 void build(int l,int r,long long o)
 23 {
 24     if (o>num) num=o;
 25     mul[o]=1;   add[o]=0;
 26     t[o].l=l;   t[o].r=r;
 27     if (l==r)
 28         t[o].dat=a[l];
 29     else
 30     {
 31         int mid=(l+r)/2;
 32         build(l,mid,2*o);
 33         build(mid+1,r,2*o+1);
 34         t[o].dat=t[2*o].dat+t[2*o+1].dat;
 35         modp(&t[o].dat);
 36     }
 37 }
 38 /*
 39 void push_down1(long long o,int l,int mid,int r)  //a[i]=a[i]*c
 40 {
 41     if (mul[o]!=1)
 42     {
 43         long long tmp1=mul[o];
 44         mul[2*o]=mul[2*o]*tmp1;
 45         mul[2*o+1]=mul[2*o+1]*tmp1;
 46         t[2*o].dat=t[2*o].dat*tmp1;
 47         t[2*o+1].dat=t[2*o+1].dat*tmp1;
 48         modp(&mul[2*o]);
 49         modp(&mul[2*o+1]);
 50         modp(&t[2*o].dat);
 51         modp(&t[2*o+1].dat);
 52         mul[o]=1;
 53     }
 54 }
 55 
 56 void push_down2(long long o,int l,int mid,int r)  //a[i]=a[i]+c;
 57 {
 58     if (add[o]!=0)
 59     {
 60         long long tmp2=add[o];
 61         add[2*o]+=tmp2;
 62         add[2*o+1]+=tmp2;
 63         t[2*o].dat+=tmp2*(mid-l+1);
 64         t[2*o+1].dat+=tmp2*(r-mid);
 65         modp(&add[2*o]);
 66         modp(&add[2*o+1]);
 67         modp(&t[2*o].dat);
 68         modp(&t[2*o+1].dat);
 69         add[o]=0;
 70     }
 71 }
 72 */
 73 void push_down(long long o,int l,int mid,int r)
 74 {
 75     t[2*o].dat=(t[2*o].dat*mul[o]+(mid-l+1)*add[o])%p;
 76     t[2*o+1].dat=(t[2*o+1].dat*mul[o]+(r-mid)*add[o])%p;
 77     
 78     mul[2*o]=(mul[2*o]*mul[o])%p;
 79     mul[2*o+1]=(mul[2*o+1]*mul[o])%p;
 80     
 81     add[2*o]=(add[2*o]*mul[o]+add[o])%p;
 82     add[2*o+1]=(add[2*o+1]*mul[o]+add[o])%p;
 83     
 84     mul[o]=1;    add[o]=0;
 85 }
 86 
 87 long long query_sum(int l,int r,long long o)
 88 {
 89     if (o>num) return 0;
 90     int tl=t[o].l,tr=t[o].r;
 91     if ((tl==l)&&(tr==r))
 92         return t[o].dat;
 93     else
 94     {
 95         int mid=(tl+tr)/2;
 96 
 97         //if (opr==1)
 98         //    push_down1(o,tl,mid,tr);
 99         //else if (opr==2)
100         //    push_down2(o,tl,mid,tr);
101         push_down(o,tl,mid,tr);
102 
103         if (r<=mid)
104         {
105             tmp=query_sum(l,r,2*o);
106             modp(&tmp);
107             return tmp;
108         }
109         else if (l>mid)
110         {
111             tmp=query_sum(l,r,2*o+1);
112             modp(&tmp);
113             return tmp;
114         }
115         else
116         {
117             long long t1=query_sum(l,mid,2*o);
118             long long t2=query_sum(mid+1,r,2*o+1);
119             tmp=(t1%p)+(t2%p);
120             modp(&tmp);
121             return tmp;
122         }
123     }
124 }
125 
126 void update(int l,int r,long long o)
127 {
128     if (o>num) return;
129     int tl=t[o].l,tr=t[o].r;
130     if ((tl==l)&&(tr==r))
131     {
132         /*
133         if (opr==2)
134         {
135             t[o].dat+=(tr-tl+1)*md;
136             add[o]+=md;
137             modp(&t[o].dat);
138             modp(&add[o]);
139         }
140         else if (opr==1)
141         {
142             t[o].dat=t[o].dat*md;
143             mul[o]=mul[o]*md;
144             modp(&t[o].dat);
145             modp(&mul[o]);
146         }
147         */
148         if (opr==1)
149         {
150             t[o].dat=(t[o].dat*md)%p;
151             mul[o]=(mul[o]*md)%p;
152             add[o]=(add[o]*md)%p;
153         }
154         else if (opr==2)
155         {
156             t[o].dat=(t[o].dat+(tr-tl+1)*md)%p;
157             add[o]=(add[o]+md)%p;
158         }
159         return;
160     }
161     else
162     {
163         int mid=(tl+tr)/2;
164         //if (tl==tr) return;
165 
166         push_down(o,tl,mid,tr);
167 
168         if (r<=mid)
169             update(l,r,2*o);
170         else if (l>mid)
171             update(l,r,2*o+1);
172         else
173         {
174             update(l,mid,2*o);
175             update(mid+1,r,2*o+1);
176         }
177         t[o].dat=(t[2*o].dat+t[2*o+1].dat)%p;
178     }
179 }
180 
181 void debug()
182 {
183     cout<<"This is the beginning"<<endl;
184     for (int i=1;i<=num;i++)
185         cout<<i<<" | "<<t[i].l<<" "<<t[i].r<<" "<<t[i].dat<<" |-and-| add="<<add[i]<<"  mul="<<mul[i]<<endl;
186     cout<<"This is the ending"<<endl;
187 }
188 
189 void debug1()
190 {
191     if (opr==3)
192     {
193         cout<<"The correct ans should be : ";
194         long long aw=0;
195         for (int i=ml;i<=mr;i++)
196             aw+=a[i];
197         cout<<aw<<endl;
198     }
199     else
200     {
201         if (opr==1)
202         {
203             for (int i=ml;i<=mr;i++)
204                 a[i]=a[i]*md;
205         }
206         else if (opr==2)
207         {
208             for (int i=ml;i<=mr;i++)
209                 a[i]=a[i]+md;
210         }
211         cout<<"debug sequence: ";
212         for (int i=1;i<=n;i++)
213             cout<<a[i]<<" ";
214         cout<<endl;
215     }
216 }
217 
218 int main()
219 {
220     freopen("seq.in","r",stdin);
221     freopen("seq.out","w",stdout);
222     
223     scanf("%d %lld",&n,&p);
224     for (int i=1;i<=n;i++)
225         cin>>a[i];
226 
227     memset(t,0,sizeof(t));
228     build(1,n,1);
229     //memset(add,0,sizeof(add));
230     //memset(mul,0,sizeof(mul));
231 //    debug();
232 
233     cin>>m;
234     for (int i=1;i<=m;i++)
235     {
236         cin>>opr;
237         if (opr==3)
238         {
239             cin>>ml>>mr;
240             ans=query_sum(ml,mr,1);
241             //debug1();
242             ans=ans%p;
243             cout<<ans<<endl;
244         }
245         else
246         {
247             cin>>ml>>mr>>md;
248             update(ml,mr,1);
249             //debug1();
250         }
251     }
252 }
View Code
原文地址:https://www.cnblogs.com/pdev/p/3955419.html