线段树

线段树的应用:

  线段树主要用来维护一些有关于区间的问题,比如说区间的最值,区间和等一系列满足结合律的问题。

   满足结合律是指这个大区间的答案是由其中的许多小区间的答案组合而成,比如说最大值,这个区间的最大值就是其中的小区间中的所有值得最大值。

  对于线段树来说,代码量比较长,不易于实现,而且所需空间也比较大,但是比较高效。

线段树的模板:

1.单点修改,区间查询。

     以求区间求和为例:https://loj.ac/problem/130

     

#include<bits/stdc++.h>
using namespace std;
#define register int
#define ll long long
#define INF 0x3f3f3f3f 
#define maxn 1000009
#define maxm
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
    return x*f;
}
ll sum[maxn<<2],val[maxn<<2];
ll n,m,k,ans,tot;

void built(int p,int l,int r)
{
    if(l==r)
    {
        sum[p]=val[l];
        return ;
    }
    int mid=(l+r)>>1;
    built(p<<1,l,mid);
    built((p<<1)+1,mid+1,r);
    sum[p]=sum[p<<1]+sum[(p<<1)+1];
//    cout<<"nice "<<p<<" "<<l<<" "<<r<<" "<<sum[p]<<endl;
}

void Update(int p,int l,int r,ll now,ll k)
{
    if(l==r&&now==l)
    {
        sum[p]+=k;
        return ;
    }
    int mid=(l+r)>>1;
    if(now<=mid)
        Update(p<<1,l,mid,now,k);
    else
        Update((p<<1)+1,mid+1,r,now,k);
    sum[p]=sum[p<<1]+sum[(p<<1)+1];
    
}

ll Query(int p,int l,int r,ll nl,ll nr)
{
    if(nl<=l&&r<=nr)
        return sum[p]; 
    ll res=0;
    int mid=(l+r)>>1;
    if(nl<=mid)
        res+=Query(p<<1,l,mid,nl,nr);
    if(nr>mid)
        res+=Query((p<<1)+1,mid+1,r,nl,nr);    
    //cout<<p<<" "<<l<<" "<<r<<" "<<nl<<" "<<nr<<" "<<res<<endl;
    return res;
}
int main()
{
//    freopen(".in","r",stdin);
//    freopen(".out","w",stdout);
    n=read(),m=read();
    for(int i=1;i<=n;i++)
        val[i]=read();
    built(1,1,n);
    for(int i=1;i<=m;i++)
    {
        int opt=read();
        ll a=read(),b=read();
        if(opt==1)
            Update(1,1,n,a,b);
        else
            printf("%lld
",Query(1,1,n,a,b));
    }
    fclose(stdin);
    fclose(stdout);
    return 0;
}

 2.区间修改,区间查询。

    以区间加法,区间求和为例:https://www.luogu.org/problemnew/show/P3372

    其实与上一个模板相比只是多了一个add数组需要维护而已,不过要记牢Update和Query的时候要将add标记(也称Lazy标记)下传给子区间。

 

#include<bits/stdc++.h>
using namespace std;
#define register int
#define ll long long
#define INF 0x3f3f3f3f 
#define maxn 1000009
#define maxm
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
    return x*f;
}
ll sum[maxn<<2],val[maxn<<2],add[maxn<<2];
ll n,m,k,ans,tot;
#define ls(p) p<<1
#define rs(p) p<<1|1 

void push_up(ll p)
{
    sum[p]=sum[ls(p)]+sum[rs(p)];
}

void built(ll p,int l,ll r)
{
    if(l==r)
    {
        sum[p]=val[l];
        return ;
    }
    ll mid=(l+r)>>1;
    built(ls(p),l,mid);
    built(rs(p),mid+1,r);
    push_up(p);
}

void pass(ll p,ll l,ll r,ll k)
{
    add[p]+=k;
    sum[p]+=(r-l+1)*k;
}
void push_down(ll p,ll l,ll r)
{
    ll mid=(l+r)>>1;
    pass(ls(p),l,mid,add[p]);
    pass(rs(p),mid+1,r,add[p]);
    add[p]=0;
}
void Update(ll p,ll l,ll r,ll nl,ll nr,ll k)
{
    if(nl<=l&&r<=nr)
    {
        add[p]+=k;
        sum[p]+=(r-l+1)*k;
        return ;
    }
    push_down(p,l,r);
    ll mid=(l+r)>>1;
    if(nl<=mid)
        Update(ls(p),l,mid,nl,nr,k);
    if(mid<nr)
        Update(rs(p),mid+1,r,nl,nr,k);
    push_up(p); 
}

ll Query(ll p,ll l,ll r,ll nl,ll nr)
{    
    ll res=0;
    if(nl<=l&&r<=nr)
        return sum[p];    
    push_down(p,l,r);
    ll mid=(l+r)>>1;
    if(nl<=mid)
        res+=Query(ls(p),l,mid,nl,nr);
    if(mid<nr)
        res+=Query(rs(p),mid+1,r,nl,nr);
    return res;
}
int main()
{
//    freopen(".in","r",stdin);
//    freopen(".out","w",stdout);
    n=read(),m=read();
    for(int i=1;i<=n;i++)
        val[i]=read();
    built(1,1,n);
    for(int i=1;i<=m;i++)
    {
        ll opt=read(),a=read(),b=read();
        ll c;
        if(opt==1)
            c=read(),Update(1,1,n,a,b,c);
        else
            printf("%lld
",Query(1,1,n,a,b));
    }
    fclose(stdin);
    fclose(stdout);
    return 0;
}

    

    再来一个区间加法,区间乘法,区间求和的例子:https://www.luogu.org/problemnew/show/P3373

   这个和上一个模板又有不同点,这次需要将区间都乘以一个数,因此我们需要维护一个mul数组,表示这个区间的乘法标记,同时也还需要维护add,sum数组。

由于乘法的出现,add在更新以及下传的过程中将会与以往不同。因为乘法的优先级高于加法,所以add必须先乘以所要乘的数k,再加上所要加的数kk,举个例子吧:

假设当前区间的和为X,所要乘的数为K,所要加的数为KK,这个区间本来的mul标记的值为a,add标记的值为b,那么新的区间和sum=K(a*X+b)+KK,将其展开便得到:

新的乘法标记为Ka,加法标记为Kb+KK。

写法1:

#include<bits/stdc++.h>
using namespace std;
#define ls(p) p<<1
#define rs(p) p<<1|1
#define maxn 200009
#define ll long long
ll n,m,ans,tot,base;
ll val[maxn],sum[maxn<<2],mul[maxn<<2],add[maxn<<1];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch-'0');ch=getchar();}
    return x*f;
}

void push_up(ll p)
{
    sum[p]=(sum[ls(p)]+sum[rs(p)])%base;
}


void pass(ll p,ll l,ll r,ll mu,ll ad)
{
    sum[p]=(sum[p]*mu)%base;
    sum[p]=(sum[p]+(r-l+1)*ad)%base;
    mul[p]=(mul[p]*mu)%base;
    add[p]=(add[p]*mu)%base;
    add[p]=(add[p]+ad)%base;
    
}
void push_down(ll p,ll l,ll r)
{
    ll mid=(l+r)>>1;
    pass(ls(p),l,mid,mul[p],add[p]);
    pass(rs(p),mid+1,r,mul[p],add[p]);
    mul[p]=1;
    add[p]=0;    
}

void built(ll p,ll l,ll r)
{
    add[p]=0;
    mul[p]=1;
    if(l==r)
    {
        sum[p]=val[l];
        return ;
    }
    ll mid=(l+r)>>1;
    built(ls(p),l,mid);
    built(rs(p),mid+1,r);
    push_up(p);
}

void Update_mul(ll nl,ll nr,ll l,ll r,ll p,ll k)
{
    if(nl<=l&&r<=nr)
    {
        mul[p]=(mul[p]*k)%base;
        add[p]=(add[p]*k)%base;
        sum[p]=(sum[p]*k)%base;
        return ;
    }
    if(mul[p]!=1||add[p])
        push_down(p,l,r);
    ll mid=(l+r)>>1;
    if(nl<=mid)
        Update_mul(nl,nr,l,mid,ls(p),k);
    if(nr>mid)
        Update_mul(nl,nr,mid+1,r,rs(p),k);
    push_up(p);
}

void Update_add(ll nl,ll nr,ll l,ll r,ll p,ll k)
{
    if(nl<=l&&r<=nr)
    {
        add[p]=(add[p]+k)%base;
        sum[p]=(sum[p]+(r-l+1)*k)%base;
        return ;
    }
    if(add[p]||mul[p]!=1)
        push_down(p,l,r);
    ll mid=(l+r)>>1;
    if(nl<=mid)
        Update_add(nl,nr,l,mid,ls(p),k);
    if(nr>mid)
        Update_add(nl,nr,mid+1,r,rs(p),k);
    push_up(p);
}

ll Query(ll nl,ll nr,ll l,ll r,ll p)
{
    ll res=0;
    if(nl<=l&&r<=nr)
        return sum[p];
    if(add[p]||mul[p]!=1)
        push_down(p,l,r);
    ll mid=(l+r)>>1;
    if(nl<=mid)
        res+=Query(nl,nr,l,mid,ls(p))%base;
    if(nr>mid)
        res+=Query(nl,nr,mid+1,r,rs(p))%base;
    return res%base;
}
int main()
{
    n=read(),m=read(),base=read();
    for(int i=1;i<=n;i++)
        val[i]=read();
    built(1,1,n);
    for(int i=1;i<=m;i++)
    {
        ll opt,x,y,k;
        opt=read(),x=read(),y=read();
        if(opt==1)
        {
            k=read();
            Update_mul(x,y,1,n,1,k);
        }
        else if(opt==2)
        {
            k=read();
            Update_add(x,y,1,n,1,k);
        }
        else 
        {
            printf("%lld
",Query(x,y,1,n,1)%base);
        }
    }
    return 0;
}

写法2:

#include<bits/stdc++.h>
using namespace std;
#define re register int
#define ll long long
#define INF 0x3f3f3f3f
#define maxn 100009
#define maxm
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
    return x*f;
}
ll sum[maxn<<2],add[maxn<<2],mul[maxn<<2],val[maxn];
int n,m,k,ans,tot,pp; 

#define ls(p) p<<1
#define rs(p) p<<1|1

void push_up(int p)
{
    sum[p]=(sum[ls(p)]+sum[rs(p)])%pp;
}

void built(int p,int l,int r)
{
    mul[p]=1,add[p]=0; 
    if(l==r)
    {
        sum[p]=val[l];
        return ;
    }
    int mid=(l+r)>>1;
    built(ls(p),l,mid);
    built(rs(p),mid+1,r);
    push_up(p);
}
void pass(int p,int l,int r,ll k,ll kk)
{
    sum[p]=(sum[p]*k)%pp;
    sum[p]=(sum[p]+kk*(r-l+1))%pp;
    mul[p]=(mul[p]*k)%pp;
    add[p]=(add[p]*k)%pp;
    add[p]=(add[p]+kk)%pp;
}

void push_down(int p,int l,int r)
{
    int mid=(l+r)>>1;
    pass(ls(p),l,mid,mul[p],add[p]);
    pass(rs(p),mid+1,r,mul[p],add[p]);
    mul[p]=1,add[p]=0;
}

void Update(int p,int l,int r,int nl,int nr,ll k,ll kk)
{
    if(nl<=l&&r<=nr)
    {
        mul[p]=(mul[p]*k)%pp;
        add[p]=(add[p]*k)%pp;
        add[p]=(add[p]+kk)%pp;
        sum[p]=(sum[p]*k)%pp;
        sum[p]=(sum[p]+(r-l+1)*kk)%pp;
        return ;
    }
    if(add[p]||mul[p]!=1)
        push_down(p,l,r);
    int mid=(l+r)>>1;
    if(nl<=mid)
        Update(ls(p),l,mid,nl,nr,k,kk);
    if(mid<nr)
        Update(rs(p),mid+1,r,nl,nr,k,kk);
    push_up(p);
}


ll Query(int p,int l,int r,int nl,int nr)
{
    ll res=0;
    if(nl<=l&&r<=nr)
        return sum[p];
    if(add[p]||mul[p]!=1)
        push_down(p,l,r);
    int mid=(l+r)>>1;
    if(nl<=mid)
        res+=Query(ls(p),l,mid,nl,nr)%pp;
    if(mid<nr)
        res+=Query(rs(p),mid+1,r,nl,nr)%pp;
    return res%pp; 
}
int main()
{
//    freopen(".in","r",stdin);
//    freopen(".out","w",stdout);
    n=read(),m=read(),pp=read();
    for(int i=1;i<=n;i++)
        val[i]=read();
    built(1,1,n);
//    cout<<sum[1]<<" "<<add[1]<<" "<<mul[1]<<endl;
    for(int i=1;i<=m;i++)
    {
        int opt=read(),x=read(),y=read();
        ll z;
        if(opt==1)
        {
            z=read();
            Update(1,1,n,x,y,z,0);
        }
        if(opt==2)
        {
            z=read();
            Update(1,1,n,x,y,1,z);
         } 
        if(opt==3)
            printf("%lld
",Query(1,1,n,x,y)%pp);
    }
    fclose(stdin);
    fclose(stdout);
    return 0;
}

习题报告:

  对于一些线段树的应用还是比较灵活的,在使用的时候要注意变通。

   L1198 最大数:https://www.luogu.org/problemnew/show/P1198

   解题思路:由于线段树不支持在线添加,我们可以稍微改变一下思路,如果是要在序列尾部添加数的话,也就相当于修改序列尾部的空值。

所以此题只需要将n以后的点合理赋值再进行单点修改,区间查询的操作,不就变成了模板1了嘛。

   L4145 上帝造题的七分钟2/花神游历各国:https://www.luogu.org/problemnew/show/P4145

   解题思路:先考虑一个性质,一个数经过若干次操作开方后必然成为1或者0(那个数本身就为0),成为1或0之后再怎么开方也不会改变数值了,而这个若干次操作很小,1e9在十余次开方后也是1,所以可以用线段树来维护区间的最值和区间的和,如果区间的最值<=1,那么直接返回就可以了。

#include<bits/stdc++.h>
using namespace std;
#define re register int
#define ll long long
#define INF 0x3f3f3f3f
#define maxn 100009
#define maxm
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
    return x*f;
}
ll sum[maxn<<2],mx[maxn<<2],val[maxn];
int n,m,k,ans,tot;

#define ls(p) p<<1
#define rs(p) p<<1|1

void push_up(int p)
{
    sum[p]=sum[ls(p)]+sum[rs(p)];
    mx[p]=max(mx[ls(p)],mx[rs(p)]);
}
void built(int p,int l,int r)
{
    if(l==r)
    {
        sum[p]=val[l];
        mx[p]=val[l];
        return ;
    }
    int mid=(l+r)>>1;
    built(ls(p),l,mid);
    built(rs(p),mid+1,r);
    push_up(p);
}

void Update(int p,int l,int r,int nl,int nr)
{
    if(mx[p]<=1)
        return ;
    if(l==r)
    {
        sum[p]=sqrt(sum[p]);
        mx[p]=sum[p];
        return ;
    }
    int mid=(l+r)>>1;
    if(nl<=mid)
        Update(ls(p),l,mid,nl,nr);
    if(mid<nr)
        Update(rs(p),mid+1,r,nl,nr);
    push_up(p);
}

ll Query(int p,int l,int r,int nl,int nr)
{
    ll res=0;
    if(nl<=l&&r<=nr)
        return sum[p];
    int mid=(l+r)>>1;
    if(nl<=mid)
        res+=Query(ls(p),l,mid,nl,nr);
    if(mid<nr)
        res+=Query(rs(p),mid+1,r,nl,nr);
    return res;
}
int main()
{
//    freopen(".in","r",stdin);
//    freopen(".out","w",stdout);
    n=read();
    for(int i=1;i<=n;i++)
        val[i]=read();
    built(1,1,n);
    m=read();
    for(int i=1;i<=m;i++)
    {
        int opt=read(),x=read(),y=read();
        if(x>y)
            swap(x,y);
        if(opt==2)
            Update(1,1,n,x,y);
        else
            printf("%lld
",Query(1,1,n,x,y));
    }
    fclose(stdin);
    fclose(stdout);
    return 0;
}
/*
4
1 100 5 5
5
1 1 2
2 1 2
1 1 2
2 2 3
1 1 4
*/
View Code
原文地址:https://www.cnblogs.com/Dxy0310/p/9751885.html