[ZJOI2008]树的统计(树链剖分)

原题

洛谷
BZOJ

Solution

这道题一看就是树链剖分的模板题：用重链剖分把任意一条树上路径拆成 O(log n) 段 DFS 序上连续的区间，再用线段树维护区间和与区间最大值，单点修改与路径查询即可在 O(log² n) 内完成。

#include<stdio.h>
#include<stdlib.h>
#define ll long long
// Larger of two 64-bit values (returns b when they are equal -- same value either way).
long long max(long long a,long long b){
    return a<b ? b : a;
}
// Exchange two ints in place through their references.
void swap(int &a,int &b){
    int saved=b;
    b=a;
    a=saved;
}
/*
 * Fast reader for one decimal int from stdin.
 * Skips every non-digit character; a '-' seen while skipping makes the result
 * negative. Reading stops at (and consumes) the first non-digit after the number.
 * Returns 0 if end of input is reached before any digit.
 */
int gi(){
    int value=0,sign=1;
    int ch=getchar();   // int, not char: getchar() returns EOF, which a char cannot represent reliably
    while(ch!=EOF && (ch<'0' || ch>'9')){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return sign*value;
}
/*
 * Fast reader for one decimal long long from stdin (same rules as gi()).
 * Skips non-digit characters; a '-' seen while skipping negates the result.
 * Returns 0 if end of input is reached before any digit.
 */
long long gl(){
    long long value=0,sign=1;
    int ch=getchar();   // int, not char, so EOF terminates the skip loop instead of spinning forever
    while(ch!=EOF && (ch<'0' || ch>'9')){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return sign*value;
}
// Array capacity; node ids (1..n) are used directly as indices, so n must stay below this.
const int maxn=100010;
// Edge pool for the adjacency lists. main inserts every undirected edge in both directions, hence maxn*2.
struct node{
    int to,nxt;   // to: target node; nxt: index of the next edge leaving the same node (0 ends the list)
}e[maxn*2];
// cnt: number of edges used in e; front[u]: first edge leaving u (0 = none); root: DFS root.
// Filled by dfs1: son = heavy child (0 if none), dep = depth (root has depth 1), fa = parent, siz = subtree size.
// Filled by dfs2: id = node's position in heavy-child-first DFS order, top = head of its heavy chain,
// num = number of positions assigned so far (ends equal to the number of reached nodes).
int cnt,front[maxn],root,son[maxn],dep[maxn],fa[maxn],siz[maxn],id[maxn],top[maxn],num;
// w[u]: weight of node u as read from input (not updated by later changes);
// b[p]: weight of the node placed at DFS position p, used to build the segment-tree leaves.
ll b[maxn],w[maxn];
// Insert the directed edge u -> v at the head of u's adjacency list.
void Add(int u,int v){
    ++cnt;
    e[cnt].nxt=front[u];
    e[cnt].to=v;
    front[u]=cnt;
}
/*
 * First heavy-light pass: record parent f and depth d of u, compute subtree
 * sizes, and pick as heavy child (son) the child with the largest subtree
 * (the first one encountered wins ties).
 */
void dfs1(int u,int f,int d){
    fa[u]=f;
    dep[u]=d;
    siz[u]=1;
    for(int edge=front[u];edge!=0;edge=e[edge].nxt){
        int child=e[edge].to;
        if(child==f)continue;   // do not walk back to the parent
        dfs1(child,u,d+1);
        siz[u]+=siz[child];
        if(son[u]==0 || siz[child]>siz[son[u]])son[u]=child;
    }
}
/*
 * Second heavy-light pass: assign DFS positions visiting the heavy child first,
 * so every heavy chain occupies a contiguous range of positions.
 * f is the head of the chain u belongs to; light children start new chains.
 * Also copies each node's weight into b at its position.
 */
void dfs2(int u,int f){
    top[u]=f;
    num++;
    id[u]=num;
    b[num]=w[u];
    if(son[u]){
        dfs2(son[u],f);   // heavy child continues the current chain
        for(int edge=front[u];edge;edge=e[edge].nxt){
            int child=e[edge].to;
            if(child==fa[u] || child==son[u])continue;
            dfs2(child,child);   // light child heads its own chain
        }
    }
}
// Segment-tree node over DFS positions: val holds the range sum, max the range maximum (as combined by pushup).
// Stored as a 1-indexed implicit tree (children of o at ls/rs), sized 4*maxn.
struct tree{
    ll max,val;
}t[4*maxn];
// Left/right child indices of segment-tree node `o` (a variable named o must be in scope).
// Fully parenthesized so the expansion stays one operand inside larger expressions.
#define ls (o<<1)
#define rs (o<<1|1)
void pushup(int o){
    t[o].val=t[ls].val+t[rs].val;
    t[o].max=max(t[ls].max,t[rs].max);
}
// Build the subtree rooted at node o, which covers DFS positions [l, r], from the values in b.
void build(int o,int l,int r){
    if(l!=r){
        int mid=(l+r)/2;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(o);
        return;
    }
    // Leaf: a single position, so its sum and its maximum are both b[l].
    t[o].max=b[l];
    t[o].val=b[l];
}
// Set the value at DFS position pos to k, then refresh every ancestor on the way back up.
void update(int o,int l,int r,int pos,ll k){
    if(l!=r){
        int mid=(l+r)/2;
        if(pos>mid)update(rs,mid+1,r,pos,k);
        else update(ls,l,mid,pos,k);
        pushup(o);
        return;
    }
    t[o].max=k;
    t[o].val=k;
}
// Sum of values on positions [posl, posr], which must lie within node o's range [l, r].
ll query1(int o,int l,int r,int posl,int posr){
    if(l>=posl && r<=posr)return t[o].val;
    int mid=(l+r)/2;
    if(posr<=mid)return query1(ls,l,mid,posl,posr);
    if(posl>mid)return query1(rs,mid+1,r,posl,posr);
    // The range straddles mid: split it at mid so each half stays inside its child.
    ll leftPart=query1(ls,l,mid,posl,mid);
    ll rightPart=query1(rs,mid+1,r,mid+1,posr);
    return leftPart+rightPart;
}
// Maximum value on positions [posl, posr], which must lie within node o's range [l, r].
ll query2(int o,int l,int r,int posl,int posr){
    if(l>=posl && r<=posr)return t[o].max;
    int mid=(l+r)/2;
    if(posr<=mid)return query2(ls,l,mid,posl,posr);
    if(posl>mid)return query2(rs,mid+1,r,posl,posr);
    // The range straddles mid: split it at mid so each half stays inside its child.
    ll leftBest=query2(ls,l,mid,posl,mid);
    ll rightBest=query2(rs,mid+1,r,mid+1,posr);
    return max(leftBest,rightBest);
}
/*
 * Sum of node weights on the tree path x..y (inclusive).
 * While x and y are on different heavy chains, add the deeper-headed chain's
 * segment [id[top[x]], id[x]] and jump to the parent of that chain's head.
 * Once they share a chain, add the contiguous range between them.
 */
ll sum(int x,int y){
    ll total=0;
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        total+=query1(1,1,num,id[top[x]],id[x]);
    }
    int upper=x,lower=y;
    if(dep[upper]>dep[lower]){
        upper=y;
        lower=x;
    }
    total+=query1(1,1,num,id[upper],id[lower]);
    return total;
}
/*
 * Maximum node weight on the tree path x..y (inclusive), via heavy-light
 * decomposition. While x and y are on different chains, take the maximum over
 * the deeper-headed chain's segment [id[top[x]], id[x]] and jump above its head.
 * Once they share a chain, finish with the contiguous range between them.
 *
 * The running maximum is seeded with x's own current weight. x always lies on
 * the path, so the seed is a real path value and the result is correct for any
 * weight range. A fixed sentinel such as -30000 is wrong as soon as every
 * weight on the path is below it.
 */
ll big(int x,int y){
    ll ans=query2(1,1,num,id[x],id[x]);
    while(top[x]!=top[y]){
        // Always climb from the endpoint whose chain head is deeper.
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        ans=max(ans,query2(1,1,num,id[top[x]],id[x]));
        x=fa[top[x]];
    }
    // Same chain: positions run contiguously from the shallower node to the deeper one.
    if(dep[x]>dep[y])swap(x,y);
    ans=max(ans,query2(1,1,num,id[x],id[y]));
    return ans;
}
int main(){
    int i,j,k,n,m;
    n=gi();
    for(i=1;i<n;i++){
        int u=gi(),v=gi();
        Add(u,v);Add(v,u);
    }
    for(i=1;i<=n;i++)w[i]=gl();
    root=1;
    dfs1(root,0,1);#include<stdio.h>
#include<stdlib.h>
#define ll long long
ll max(ll a,ll b){
    if(a>b)return a;
    return b;
}
void swap(int &a,int &b){
    int tmp=a;a=b;b=tmp;
}
int gi(){
    int sum=0,f=1;char ch=getchar();
    while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
    return f*sum;
}
ll gl(){
    ll sum=0,f=1;char ch=getchar();
    while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
    return f*sum;
}
const int maxn=100010;
struct node{
    int to,nxt;
}e[maxn*2];
int cnt,front[maxn],root,son[maxn],dep[maxn],fa[maxn],siz[maxn],id[maxn],top[maxn],num;
ll b[maxn],w[maxn];
void Add(int u,int v){
    e[++cnt].to=v;e[cnt].nxt=front[u];front[u]=cnt;
}
void dfs1(int u,int f,int d){
    fa[u]=f;dep[u]=d;siz[u]=1;
    for(int i=front[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v!=f){
            dfs1(v,u,d+1);
            siz[u]+=siz[v];
            if(!son[u] || siz[son[u]]<siz[v])son[u]=v;
        }
    }
}
void dfs2(int u,int f){
    top[u]=f;id[u]=++num;b[num]=w[u];
    if(!son[u])return;
    dfs2(son[u],f);
    for(int i=front[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v!=fa[u] && v!=son[u])dfs2(v,v);
    }
}
struct tree{
    ll max,val;
}t[4*maxn];
#define ls o<<1
#define rs o<<1|1
void pushup(int o){
    t[o].val=t[ls].val+t[rs].val;
    t[o].max=max(t[ls].max,t[rs].max);
}
void build(int o,int l,int r){
    if(l==r){
        t[o].val=t[o].max=b[l];return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);build(rs,mid+1,r);
    pushup(o);
}
void update(int o,int l,int r,int pos,ll k){
    if(l==r){
        t[o].val=k;t[o].max=k;return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)update(ls,l,mid,pos,k);
    else update(rs,mid+1,r,pos,k);
    pushup(o);
}
ll query1(int o,int l,int r,int posl,int posr){
    if(posl<=l && r<=posr)return t[o].val;
    int mid=(l+r)>>1;
    if(mid>=posr)return query1(ls,l,mid,posl,posr);
    if(mid<posl)return query1(rs,mid+1,r,posl,posr);
    return query1(ls,l,mid,posl,mid)+query1(rs,mid+1,r,mid+1,posr);
}
ll query2(int o,int l,int r,int posl,int posr){
    if(posl<=l && r<=posr)return t[o].max;
    int mid=(l+r)>>1;
    if(mid>=posr)return query2(ls,l,mid,posl,posr);
    if(mid<posl)return query2(rs,mid+1,r,posl,posr);
    return max(query2(ls,l,mid,posl,mid),query2(rs,mid+1,r,mid+1,posr));
}
ll sum(int x,int y){
    ll ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        ans+=query1(1,1,num,id[top[x]],id[x]);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    ans+=query1(1,1,num,id[x],id[y]);
    return ans;
}
ll big(int x,int y){
    ll ans=-30000;
    int s=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        ans=max(ans,query2(1,1,num,id[top[x]],id[x]));
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    ans=max(ans,query2(1,1,num,id[x],id[y]));
    return ans;
}
int main(){
    int i,j,k,n,m;
    n=gi();
    for(i=1;i<n;i++){
        int u=gi(),v=gi();
        Add(u,v);Add(v,u);
    }
    for(i=1;i<=n;i++)w[i]=gl();
    root=1;
    dfs1(root,0,1);
    dfs2(root,root);
    build(1,1,n);

    scanf("%d",&m);
    for(i=1;i<=m;i++){
        char op[10];scanf("%s",op);
        if(op[0]=='C'){
            int u;ll t;scanf("%d%lld",&u,&t);
            update(1,1,n,id[u],t);
        }
        else if(op[1]=='S'){
            int u,v;scanf("%d%d",&u,&v);
            printf("%lld
",sum(u,v));
        }
        else{
            int u,v;scanf("%d%d",&u,&v);
            printf("%lld
",big(u,v));
        }
    }
    return 0;
}

    dfs2(root,root);
    build(1,1,n);

    scanf("%d",&m);
    for(i=1;i<=m;i++){
        char op[10];scanf("%s",op);
        if(op[0]=='C'){
            int u;ll t;scanf("%d%lld",&u,&t);
            update(1,1,n,id[u],t);
        }
        else if(op[1]=='S'){
            int u,v;scanf("%d%d",&u,&v);
            printf("%lld
",sum(u,v));
        }
        else{
            int u,v;scanf("%d%d",&u,&v);
            printf("%lld
",big(u,v));
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/cjgjh/p/9833514.html