树链剖分 洛谷3038

题目链接:https://www.luogu.org/problemnew/show/P3038

一开始以为是简单的查询和修改树上面两个点之间的路径,直接打了一遍板子,发现样例都过不了,后来才发现这里的权值是边权,不是点权了,之前只用树链剖分对树上两点之间点权进行查询和修改,现在要我们对边权进行操作了,那么怎么计算点权呢?

后来想到可以用点权来代表边权,联想到树的性质,除了根节点之外其他的所有点都只有一条树边连接这个点和它的父节点,那么对于每一条树边,我们都可以用它相邻两点之间深度更大的那个点来表示。

如果现在我们对两个点x,y之间的路径进行查询或者修改操作,那么我们肯定是要先判断在我们划分的链(好像叫重链和轻链)里面这两个点是不是处于同一条链,如果在同一条链上面的话就要把L[x](L[x]代表点x的编号)加1,因为编号为L[x]的点代表的是它头顶的那条树边,这条树边不在我们要查询的路径上面,所以我们要对L[x]加上1,如果不在同一条链上面就不用加,因为top[x]所代表的边也是在路径上面的,其他的就是板子了。

注意,我最后一个样例一直超时,后来发现要用那个read函数来节省时间,不然好像会一直TLE。

其实这道题对我们可以帮助我们理解树链剖分。

代码:

#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<map>
#include<stack>
#include<cmath>
#include<vector>
#include<set>
#include<cstdio>
#include<string>
#include<deque> 
using namespace std;
typedef long long LL;
#define eps 1e-8
#define INF 0x3f3f3f3f
#define maxn 200005
int head[maxn],tot[maxn],L[maxn],id[maxn];
//L[x]代表点x的编号,id[Time]表示编号为Time的点 
int top[maxn],son[maxn],fa[maxn],dep[maxn];
//top[x]表示点x所在的这条链上面深度最下的点
//son[x]表示x的重儿子 
int n,m,k,t;
int cnt,Time;//这里cnt用于链式前向星,Time来给点编号计数 
char op[10];
int x,y;
struct Edge{
    int next,v;
}edge[maxn<<1];
struct node{
    int l,r,w,f;
}tree[maxn<<2];
void init(){
    memset(head,-1,sizeof(head));
    memset(top,0,sizeof(top));
    memset(son,0,sizeof(son));
    memset(tot,0,sizeof(tot));
    cnt=Time=0;
}
void add(int u,int v){
    edge[++cnt].v=v;
    edge[cnt].next=head[u];
    head[u]=cnt;
}
void DFS(int u,int pre,int depth){//第一次DFS求出深度和每个点的重儿子,父亲 
    dep[u]=depth;
    fa[u]=pre;
    int maxx=-1;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==pre)
        continue;
        DFS(v,u,depth+1);
        tot[u]+=tot[v];
        if(maxx<tot[v]){
            maxx=tot[v];
            son[u]=v;
        }
    }
}
void DFS2(int u,int pre){//第二次DFS用来及计算出top数组和对点编号 
    top[u]=pre;
    id[++Time]=u;
    L[u]=Time;
    if(son[u]==0)
    return;
    DFS2(son[u],pre);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(top[v]==0)
        DFS2(v,v);
    }
}
void update(int k){
    tree[k].w=tree[k<<1].w+tree[k<<1|1].w;
}
void build(int l,int r,int k){
    tree[k].l=l;
    tree[k].r=r;
    tree[k].f=0;
    tree[k].w=0;
    if(l==r){
        return;
    }
    int mid=(l+r)/2;
    build(l,mid,k<<1);
    build(mid+1,r,k<<1|1);
    update(k);
}
void down(int k){
    tree[k<<1].w+=(tree[k<<1].r-tree[k<<1].l+1)*tree[k].f;
    tree[k<<1|1].w+=(tree[k<<1|1].r-tree[k<<1|1].l+1)*tree[k].f;
    tree[k<<1].f+=tree[k].f;
    tree[k<<1|1].f+=tree[k].f;
    tree[k].f=0;
}
void change_interval(int L,int R,int k){//区间修改 
    if(L>R)//这里面要特判一下,因为最后x和y可能代表的是同一个点,导致L>R 
    return;
    if(tree[k].l>=L&&tree[k].r<=R){
        tree[k].w+=tree[k].r-tree[k].l+1;
        tree[k].f++;
        return;
    }
    if(tree[k].f)
    down(k);
    int mid=(tree[k].l+tree[k].r)/2;
    if(L<=mid)
    change_interval(L,R,k<<1);
    if(R>mid)
    change_interval(L,R,k<<1|1);
    update(k);
}
int ask_interval(int L,int R,int k){//区间查询 
    if(L>R)//特判 
    return 0;
    if(tree[k].l>=L&&tree[k].r<=R){
        return tree[k].w;
    }
    if(tree[k].f)
    down(k);
    int mid=(tree[k].l+tree[k].r)/2;
    int ans=0;
    if(L<=mid)
    ans+=ask_interval(L,R,k<<1);
    if(R>mid)
    ans+=ask_interval(L,R,k<<1|1);
    update(k);
    return ans;
}
void change(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        change_interval(L[top[x]],L[x],1);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    change_interval(L[x]+1,L[y],1);//这里L[x]+1 
    //如果x和y是同一个点的时候是要特判的,我这里是在区间查询和修改的部分特判的 
    return;
}
int ask(int x,int y){
    int ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans+=ask_interval(L[top[x]],L[x],1);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=ask_interval(L[x]+1,L[y],1);
    return ans;
}
inline int read()
{
    char c=getchar();int num=0;
    for(;!isdigit(c);c=getchar())
        if(c=='P') return 1;
        else if(c=='Q') return 2;
    for(;isdigit(c);c=getchar())
        num=num*10+c-'0';
    return num;
}
int main()
{
        n=read(),m=read();
        init();
        int u,v;
        for(int i=0;i<n-1;i++){
            u=read(),v=read();
            add(u,v);
            add(v,u);
        }
        DFS(1,1,1);//u,pre,depth;
        DFS2(1,1);//u,pre
        build(1,n,1);
        for(int i=1;i<=m;i++){
            scanf("%s",op);
            x=read(),y=read();
            if(op[0]=='P'){
                change(x,y);
            }else{
                printf("%d
",ask(x,y));
            }
        }
    return 0;
}
原文地址:https://www.cnblogs.com/6262369sss/p/10678917.html