树的统计

题目

这是一道经典树链剖分

维护区间和,区间最大值

#include<cstdio>
#include<algorithm>
using namespace std;
int a[30005],top[30005],rev[120005],seg[30005],father[30005],sum[120005],n,k1;
int summ,maxx,size[30005],d[30005],h[30005],son[30005],cnt,tot=0,m[120005];
struct edge{
    int to,next;
}e[120005];
void add(int x,int y)
{
    e[++tot].to=y;
    e[tot].next=h[x];
    h[x]=tot;
}
void dfs1(int u,int fa)
{
    father[u]=fa;
    d[u]=d[fa]+1;
    size[u]=1;
    for (int i=h[u];i!=0;i=e[i].next)
    {
        int v=e[i].to;
        if (v==fa) continue;
        dfs1(v,u);
        size[u]+=size[v];
        if (size[v]>size[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int fa)
{
    if (son[u])
    {
        seg[son[u]]=++cnt;
        top[son[u]]=top[u];
        rev[cnt]=son[u];
        dfs2(son[u],u);
    }
    for (int i=h[u];i!=0;i=e[i].next)
    {
        int v=e[i].to;
        if (!top[v])
        {
            seg[v]=++cnt;
            rev[cnt]=v;
            top[v]=v;
            dfs2(v,u);
        }
    }
}
void build(int k,int l,int r)//建树
{
    int mid=(l+r)>>1;
    if (l==r)
    {
        m[k]=sum[k]=a[rev[l]];
        return;
    }
    build(k*2,l,mid);
    build(k*2+1,mid+1,r);
    sum[k]+=sum[k*2]+sum[k*2+1];
    m[k]=max(m[k*2],m[k*2+1]);
}
void change(int k,int l,int r,int v,int pos)//单点修改
{
    if (pos<l||pos>r) return;
    if (l==r&&r==pos)
    {
        sum[k]=m[k]=v;
        return;
    }
    int mid=l+r>>1;
    if (mid>=pos) change(k*2,l,mid,v,pos);
    if (mid+1<=pos) change(k*2+1,mid+1,r,v,pos);
    sum[k]=sum[k*2]+sum[k*2+1];
    m[k]=max(m[k*2],m[k*2+1]);
}
void query(int k,int l,int r,int x,int y)//区间查询
{
    if (y<l||x>r) return;
    if (x<=l&&r<=y) 
    {
        summ+=sum[k];
        maxx=max(maxx,m[k]);
        return;
    }
    int mid=l+r>>1;
    if (mid>=x) query(k*2,l,mid,x,y);
    if (mid+1<=y) query(k*2+1,mid+1,r,x,y);
}
void ask(int l,int r)
{
    int fl=top[l],fr=top[r];
    while (fl!=fr)
    {
        if (d[fl]<d[fr]) swap(l,r),swap(fl,fr);
        query(1,1,cnt,seg[fl],seg[l]);
        l=father[fl],fl=top[l];
    }
    if (d[l]>d[r]) swap(l,r);
    query(1,1,cnt,seg[l],seg[r]);
}
int main()
{
    tot=0;
    scanf("%d",&n);
    for (int i=1;i<n;i++)
    {
        int q,p;
        scanf("%d%d",&q,&p);
        add(q,p),add(p,q);
    }
    for (int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    dfs1(1,0);
    cnt=seg[1]=top[1]=rev[1]=1;
    dfs2(1,0);
    build(1,1,cnt);
    char ch[10];
    scanf("%d",&k1);
    for (int i=1;i<=k1;i++)
    {
        scanf("%s",ch+1);    
        int q,p;
        scanf("%d%d",&q,&p);
        if(ch[2]=='H') change(1,1,cnt,p,seg[q]);
        else
        {
            summ=0;
            maxx=-2147483647;
            ask(q,p);
            if (ch[2]=='M') printf("%d
",maxx);
            else printf("%d
",summ);
        }
    }
}
原文地址:https://www.cnblogs.com/nibabadeboke/p/11333074.html