线段树笔记√

线段树上的每一点表示一段区间和。


建树

首先,递归进去做,递归的参数是pos,l,r,分别表示,线段树上节点的编号(即当前编号),以及这个点表示的区间的左端点和右端点。那么终止的条件就是l=r,这个时候,node[pos].sum=a[l]

我们考虑一下l不等于r的时候,那么这个区间的左儿子就是[l,mid],右儿子就是[mid+1,r]

然后考虑一下左边儿子的节点编号是什么

我们之间令编号为x的节点的左儿子为x*2,右儿子为x*2+1

那么只有对左右儿子递归进去做就行了,做完了之后,我们更新一下sum

【注意】线段树的节点个数要开成序列长度的四倍

 1 struct data{
 2     int sum;
 3 }node[400001];
 4 void build(int pos,int l,int r)
 5 {
 6     if(l==r){
 7         node[pos].sum=a[l];
 8         return;
 9     }
10     int mid=l+r>>1;//µÈ¼Û(l+r)/2,λÔËËãÓÅÏȼ¶×îµÍ
11     int lson=pos*2,rson=pos*2+1;
12     build(lson,l,mid);
13     build(rson,mid+1,r);
14     node[pos].sum=node[pos*2].sum+node[pos*2+1].sum;
15 }
建树
void build(int pos,int l,int r)
{
    ll[pos]=l,rr[pos]=r;
    if(l==r)
    {
        max1[pos]=a[l];
        return;
    }
    int mid=l+r>>1;
    build(pos<<1,l,mid);
    build(pos<<1|1,mid+1,r);
    max1[pos]=MAX(max1[pos<<1],max1[pos<<1|1]);
}
QAQ更简单的建树

查询

比如说,你要查区间[5,6]的和,你会发现,这就是线段树上的一个节点,那直接找到这个节点就行。

//pos表示当前走到了哪个节点,l,r,表示这个节点所代表的区间,ql,qr表示你要查询的区间

我们考虑一下终止条件,是l=ql,r=qr。这个时候我们就可以直接返回node的信息,考虑一下如果不是,有哪些情况,首先我们先算出当前这段区间的中点,考虑一下,如果qr<=mid,那么我们要查询的区间都在左儿子里,这时候直接返回查询左儿子的值就行了。

还有什么情况呢,如果ql>mid,那是不是整个区间都在右儿子里,这时候就直接查右儿子就行了。

那考虑一下,如果查询的区间是一部分在左儿子里,一部分在右儿子里,要怎么办?

我们把查询的区间切开,原本要查[ql,qr],现在变成[ql,mid]+[mid+1,qr],那就是在左儿子里查询[ql,mid],在右儿子里查询[mid+1,qr]

1 int query(int pos,int l,int r,int ql,int qr)
2 {
3     if(l==ql&&r==qr)return node[pos].sum;
4     int mid=l+r>>1,lson=pos*2,rson=pos*2+1;
5     if(qr<=mid)return query(lson,l,mid,ql,qr);
6     else if(ql>mid) return query(rson,mid+1,r,ql,qr);
7     else return           query(lson,l,mid,ql,mid)+query(rson,mid+1,r,mid+1,qr);
8 }
查询
int query(int pos,int l,int r)//sum
{
    l=max(ll[pos],l),r=min(rr[pos],r);
    if(l>r)return 0;//查询max就-inf,min就inf,sum就0 
    if(l==ll[pos]&&r==rr[pos]) return max1[pos];
    int mid=l+r>>1;
    return (query(pos<<1 , l , r)+query(pos<<1|1 , l , r);
}
查询和
int query(int pos,int l,int r)//max
{
    l=max(ll[pos],l),r=min(rr[pos],r);
    if(l>r)return -1*inf;//查询max就-inf,min就inf,sum就0 
    if(l==ll[pos]&&r==rr[pos]) return max1[pos];
    int mid=l+r>>1;
    return max(query(pos<<1 , l , r) , query(pos<<1|1 , l , r));
}
查询最大值
int query(int pos,int l,int r)//min
{
    l=max(ll[pos],l),r=min(rr[pos],r);
    if(l>r)return inf;
    if(l==ll[pos]&&r==rr[pos])return min1[pos];
    int mid=l+r>>1;
    return min(query(pos<<1,l,r),query(pos<<1|1,l,r));
}
查询最小值

 修改

修改有两种情况,单点修改和区间修改。

单点修改:首先还是pos表示当前节点,l,r表示当前节点代表的区间,m表示要修改的位置,显然递归终止的条件是l=r。那考虑一下,其实一个点,要么在mid左边 要么在mid右边,直接判断一下在哪边递归进去做就行了

如果l=r,说明找到这个点了,sum+=v,然后退出

否则的话判断m是不是<=mid

是的话说明在左儿子里,递归下去修改。

不是的话,反之(模仿上述处理)

修改完了之后要注意,要记得更新节点的信息,这个点的和等于左儿子的和加上右儿子的和。

 1 int modify_dot(int pos,int l,int r,int m,int v)
 2 {
 3     if(l==r){
 4         node[pos].sum
 5         return;
 6     }
 7     int mid=l+r>>1,lson=pos*2,rson=pos*2+1;
 8     if(m<=mid)modify_dot(lson,l,mid,m,v);
 9     else modify_dot(rson,mid+1,r,m,v);
10     node[pos].sum=node[lson].sum+node[rson].sum;
11     
12 }
修改
int n,k;
int a[200005];
int ll[maxn],rr[maxn];
int max1[maxn],min1[maxn],sum[maxn];
int lazy[maxn];//下传标记->区间修改用的,初值0 

void downpush(int x)
{
    if(ll[x]==rr[x])return;
    if(lazy[x])
    {
        sum[x<<1]+=lazy[x];
        lazy[x<<1]+=lazy[x];
        sum[x<<1|1]+=lazy[x];
        lazy[x<<1|1]+=lazy[x];
        lazy[x]=0;
    }
    //lazyx表示x被加了lazyx 遍历到这个节点的时候顺便down一下维护他的儿子

}

void modify(int x,int l,int r,int k)//区间修改 当前节点x,l~r区间加k
{
    downpush(x);
    l=max(l,ll[x]),r=min(r,rr[x]);
    if(l>r)return;
    if(l==ll[x]&&r==rr[x])
    {
        lazy[x]=k;
        sum[x]+=k*(r-l+1);
    }
    else modify(x<<1,l,r,k),modify(x<<1|1,l,r,k);
}
区间修改

e.g.

先给你n,m表示序列长度,和操作次数,接下来n个数,表示原序列,接下来m行,表示操作
如果输入格式是1 l r
表示查询l到r的和
如果输入格式是2 l r
表示查询l到r的最小值
如果输入格式是3 l r
表示查询l到r的最大值
如果输入格式是4 m v
表示序列中第m个数变成v

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
long long a[100005];
struct data {
    int sum,max,min;
} node[400001];
void build(int pos,int l,int r) {
    if(l==r) {
        node[pos].sum=a[l];
        node[pos].max=a[l];
        node[pos].min=a[l];
        return;
    }
    int mid=l+r>>1;
    build(pos*2,l,mid);
    build(pos*2+1,mid+1,r);
    node[pos].sum=node[pos*2].sum+node[pos*2+1].sum;
    node[pos].max=node[pos*2].max>node[pos*2+1].max?node[pos*2].max:node[pos*2+1].max;
    node[pos].min=node[pos*2].min>node[pos*2+1].min?node[pos*2+1].min:node[pos*2].min;
}
int query(int pos,int l,int r,int ql,int qr,int x) {
    if (l==ql && r==qr)
        if (x==3)
            return node[pos].max;
        else if (x==2)
            return node[pos].min;
        else if (x==1)
            return node[pos].sum;
        else;
    else {
        int mid=l+r>>1;
        if (qr<=mid)
            return query(pos*2,l,mid,ql,qr,x);
        else if (ql>mid)
            return query(pos*2+1,mid+1,r,ql,qr,x);
        else {
            int a=query(pos*2,l,mid,ql,mid,x),b=query(pos*2+1,mid+1,r,mid+1,qr,x);
            if (x==3)
                return a>b?a:b;
            else if (x==2)
                return a>b?b:a;
            else if (x==1)
                return a+b;
        }
    }
}

void modify_dot(int pos,int v,int n) {
    node[pos].max=v;
    node[pos].min=v;
    node[pos].sum=v;

    while (pos!=1) {
        if (pos%2==0) { //×ó
            node[pos/2].max=node[pos+1].max>node[pos].max?node[pos+1].max:node[pos].max;
            node[pos/2].min=node[pos+1].min<node[pos].min?node[pos+1].min:node[pos].min;
            node[pos/2].sum=node[pos].sum+node[pos+1].sum;
        } else { //ÓÒ
            node[pos/2].max=node[pos-1].max>node[pos].max?node[pos-1].max:node[pos].max;
            node[pos/2].min=node[pos-1].min<node[pos].min?node[pos-1].min:node[pos].min;
            node[pos/2].sum=node[pos].sum+node[pos-1].sum;
        }
        pos=pos/2;//°Ö°Ö
    }
    return;
}
int getpos(int pos,int l,int r,int m) {
    if (l==r && l==m)
        return pos;
    int mid=l+r>>1;
    if (m<=mid)
        return getpos(pos*2,l,mid,m);
    else
        return getpos(pos*2+1,mid+1,r,m);
}
int main() {
    int n,m;
    memset(a,0,sizeof(a));
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; i++)scanf("%d",&a[i]);
    build(1,1,n);
/*    for (int j=1; j<=5; ++j)
        printf("%d µÄmax:%d min:%d sum:%d
",j,node[j].max,node[j].min,node[j].sum);
    printf("
");
*/
    for(int i=1; i<=m; i++) {
        int x,ql,qr;
        scanf("%d%d%d",&x,&ql,&qr);
        if(x==1) {
            printf("%d
",query(1,1,n,ql,qr,1));
        }
        if(x==2) {
            printf("%d
",query(1,1,n,ql,qr,2));
        }
        if(x==3) {
            printf("%d
",query(1,1,n,ql,qr,3));
        }
        if(x==4) {
            int pos=getpos(1,1,n,ql);
            modify_dot(pos,qr,n);
        }
    }
    return 0;
}
非常复杂的奇怪解法
原文地址:https://www.cnblogs.com/gc812/p/5773903.html