ZJOI2008]树的统计(树链剖分,线段树)

题目描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入输出格式

输入格式:

 

输入文件的第一行为一个整数n,表示节点的个数。

接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来一行n个整数,第i个整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

输出格式:

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果

思路:

树剖板子题

将树剖好后跑线段树查询即可

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define rii register int i
#define rij register int j
#define int long long 
using namespace std;
int n,head[200005],size[200005],f[200005],zs[200005],bnt;
int top[200005],nid[200005],nsd[200005],cnt,val[200005];
int nval[200005],q;
struct ljb{
    int to,nxt;
}y[400005];
struct xds{
    int maxn,sum;
}x[1000005];
inline void add(int from,int to)
{
    bnt++;
    y[bnt].to=to;
    y[bnt].nxt=head[from];
    head[from]=bnt;
}
void dfs1(int wz,int fa,int sd)
{
    f[wz]=fa;
    nsd[wz]=sd;
    size[wz]=1;
    int maxn=0;
    for(rii=head[wz];i!=0;i=y[i].nxt)
    {
        int to=y[i].to;
        if(to!=fa)
        {
            dfs1(to,wz,sd+1);
            size[wz]+=size[to];
            if(size[to]>maxn)
            {
                zs[wz]=to;
                maxn=size[to];
            }
        }
    }
}
void dfs2(int wz,int ntop)
{
    cnt++;
    nid[wz]=cnt;
    nval[cnt]=val[wz];
    top[wz]=ntop;
    if(zs[wz]==0)
    {
        return;
    }
    dfs2(zs[wz],ntop);
    for(rii=head[wz];i!=0;i=y[i].nxt)
    {
        int to=y[i].to;
        if(zs[wz]!=to&&f[wz]!=to)
        {
            dfs2(to,to);
        }
    }
}
void build(int l,int r,int bh)
{
    if(l==r)
    {
        x[bh].sum=nval[l];
        x[bh].maxn=nval[l];
        return;
    }
    int mid=(l+r)/2;
    build(l,mid,bh*2);
    build(mid+1,r,bh*2+1);
    x[bh].sum=x[bh*2].sum+x[bh*2+1].sum;
    x[bh].maxn=max(x[bh*2].maxn,x[bh*2+1].maxn);
}
void change(int wz,int nl,int nr,int val,int bh)
{
    if(nl==nr&&nl==wz)
    {
        x[bh].maxn=val;
        x[bh].sum=val;
        return;
    }
    int mid=(nl+nr)/2;
    if(wz<=mid)
    {
        change(wz,nl,mid,val,bh*2);
    }
    else
    {
        change(wz,mid+1,nr,val,bh*2+1);
    }
    x[bh].maxn=max(x[bh*2].maxn,x[bh*2+1].maxn);
    x[bh].sum=x[bh*2].sum+x[bh*2+1].sum; 
}
int querym(int l,int r,int nl,int nr,int bh)
{
    if(l<nl)
    {
        l=nl;
    }
    if(r>nr)
    {
        r=nr;
    }
    if(l==nl&&r==nr)
    {
        return x[bh].maxn;
    }
    int mid=(nl+nr)/2;
    int val=-500000;
    if(l<=mid)
    {
        int a1=querym(l,r,nl,mid,bh*2);
        val=max(a1,val);
    }
    if(r>mid)
    {
        int a2=querym(l,r,mid+1,nr,bh*2+1);
        val=max(val,a2);
    }
    return val;
}
int qmax(int from,int to)
{
    int ans=-500000;
    while(top[from]!=top[to])
    {
        if(nsd[top[from]]<nsd[top[to]])
        {
            swap(from,to);
        }
        int res=0;
        res=querym(nid[top[from]],nid[from],1,n,1);
        ans=max(ans,res);
        from=f[top[from]];
    }
    if(nsd[from]>nsd[to])
    {
        swap(from,to);
    }
    int res=0;
    res=querym(nid[from],nid[to],1,n,1);
    ans=max(ans,res);
    return ans;
}
int querys(int l,int r,int nl,int nr,int bh)
{
    if(l<nl)
    {
        l=nl;
    }
    if(r>nr)
    {
        r=nr;
    }
    if(l==nl&&r==nr)
    {
        return x[bh].sum;
    }
    int mid=(nl+nr)/2;
    int val=0;
    if(l<=mid)
    {
        val+=querys(l,r,nl,mid,bh*2);
    }
    if(r>mid)
    {
        val+=querys(l,r,mid+1,nr,bh*2+1);
    }
    return val;
}
int qsum(int from,int to)
{
    int ans=0;
    while(top[from]!=top[to])
    {
        if(nsd[top[from]]<nsd[top[to]])
        {
            swap(from,to);
        }
        int res=0;
        res=querys(nid[top[from]],nid[from],1,n,1);
        ans+=res;
        from=f[top[from]];
    }
    if(nsd[from]>nsd[to])
    {
        swap(from,to);
    }
    int res=0;
    res=querys(nid[from],nid[to],1,n,1);
    ans+=res;
    return ans;
}
signed main()
{
    for(rii=1;i<=400000;i++)
    {
        x[i].maxn=-500000;
    }
    scanf("%lld",&n);
    for(rii=1;i<=n-1;i++)
    {
        int from,to;
        scanf("%lld%lld",&from,&to);
        add(from,to);
        add(to,from);
    }
    dfs1(1,1,0);
    for(rii=1;i<=n;i++)
    {
        scanf("%lld",&val[i]);
    }
    dfs2(1,1);
    build(1,n,1);
    scanf("%lld",&q);
    for(rii=1;i<=q;i++)
    {
        int from,to,val;
        string s; 
        char c=getchar();
        while(c<'A'||c>'Z')
        {
            c=getchar();
        }
        while(c>='A'&&c<='Z')
        {
            s+=c;
            c=getchar();
        }
        if(s=="CHANGE")
        {
            scanf("%lld%lld",&from,&val);
            change(nid[from],1,n,val,1);
        }
        if(s=="QMAX")
        {
            scanf("%lld%lld",&from,&to);
            int ltt=qmax(from,to);
            printf("%lld
",ltt);
        }
        if(s=="QSUM")
        {
            scanf("%lld%lld",&from,&to);
            int ltt=qsum(from,to);
            printf("%lld
",ltt);
        }
    }
}
原文地址:https://www.cnblogs.com/ztz11/p/9904868.html