[ZJOI2008] 树的统计(树链剖分)

题目描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入输出格式

输入格式:

输入文件的第一行为一个整数n,表示节点的个数。

接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来一行n个整数,第i个整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。


输出格式:

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

输入输出样例

输入样例#1:

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

输出样例#1:

4
1
2
2
10
6
5
6
5
16

说明

对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Solution

其实就是个板子,但是好坑啊...
我的query函数如果说不在查询范围内的话我就返回0.
然后导致我错的一塌糊涂...

代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll maxn=100008;
const ll inf=19260817000300;
struct sj{ll to,next;}a[maxn*2];
ll dep[maxn],fa[maxn];
ll size,n,m,num,pd;
ll col[maxn],c[maxn];
ll head[maxn],id[maxn];
ll siz[maxn],zu[maxn],son[maxn];
ll sgm[maxn*4],sgx[maxn*4];
void add(ll x,ll y)
{
    a[++size].to=y;
    a[size].next=head[x];
    head[x]=size;
}

void dfs1(ll x)
{
    siz[x]=1;
    for(ll i=head[x];i;i=a[i].next)
    {
        ll tt=a[i].to;
        if(!siz[tt])
        {
            dep[tt]=dep[x]+1;
            fa[tt]=x;
            dfs1(tt);
            siz[x]+=siz[tt];
            if(siz[tt]>siz[son[x]])
            son[x]=tt;
        }
    }
}

void dfs2(ll x,ll anc)
{   
    zu[x]=anc; c[++num]=col[x]; id[x]=num;
    if(son[x]) dfs2(son[x],anc);
    for(ll i=head[x];i;i=a[i].next)
    {
        ll tt=a[i].to;
        if(!zu[tt])
            if(tt==son[x])
                continue;
            else dfs2(tt,tt);
    }
}

void update(ll node)
{
    sgm[node]=sgm[node*2]+sgm[node*2+1];
    sgx[node]=max(sgx[node*2],sgx[node*2+1]);
}

void build(ll node,ll l,ll r)
{
    if(l==r)
    {
        sgm[node]=c[l];
        sgx[node]=c[l];
        return;
    }
    ll mid=(l+r)/2;
    build(node*2,l,mid);
    build(node*2+1,mid+1,r);
    update(node);
}

void change(ll node,ll l,ll r,ll L,ll R,ll cc)
{
    if(l>R||L>r)return;
    if(l>=L&&r<=R)
    {
        sgm[node]=cc;
        sgx[node]=cc;
        return; 
    }
    ll mid=(l+r)/2;
    change(node*2,l,mid,L,R,cc);
    change(node*2+1,mid+1,r,L,R,cc);
    update(node);
}

ll query(ll node,ll l,ll r,ll L,ll R)
{
    if(l>R||L>r)return -inf;
    if(l>=L&&r<=R)
    {
        if(pd==1)
        return sgm[node];
        else return sgx[node];
    }
    ll mid=(l+r)/2;
    ll llx=query(node*2,l,mid,L,R);
    ll rrx=query(node*2+1,mid+1,r,L,R);
    if(pd) 
	{
		if(llx==-inf)llx=0;
		if(rrx==-inf)rrx=0;
		return llx+rrx;
	}
	return max(llx,rrx);
}

ll check(ll x,ll y)
{
    ll rest=0,rest1=-inf;
    while(zu[x]!=zu[y])
    {
        if(dep[zu[x]]<dep[zu[y]])
        swap(x,y);
        ll kk=query(1,1,n,id[zu[x]],id[x]);
        if(pd)
        rest+=kk;
       	else rest1=max(rest1,kk);
        x=fa[zu[x]];
    }
    if(id[x]>id[y])
    swap(x,y);
    
    ll tt=query(1,1,n,id[x],id[y]);
    if(pd){rest+=tt;cout<<rest<<endl;return 0;}else rest1=max(tt,rest1);
    cout<<rest1<<endl;
}
ll x,y,z;
int main()
{
    scanf("%lld",&n);
    for(ll i=1;i<n;i++)
    {
        scanf("%lld%lld",&x,&y);
        add(x,y); add(y,x);
    }
    for(ll i=1;i<=n*4;i++)
    sgx[i]=-inf;
    for(ll i=1;i<=n;i++)
    scanf("%lld",&col[i]);
    dfs1(1);
    dfs2(1,1);
    build(1,1,n);
    scanf("%lld",&m);
    while(m--)
    {
        char ch[10];pd=0; 
       	scanf("%s",ch);scanf("%lld%lld",&x,&y);
       	if(ch[1]=='M')
       		check(x,y);
        else if(ch[1]=='S')
            pd=1,check(x,y);
        else 
            change(1,1,n,id[x],id[x],y);
    }
}
原文地址:https://www.cnblogs.com/Kv-Stalin/p/9247059.html