洛谷P1600 天天爱跑步——树上差分

题目:https://www.luogu.org/problemnew/show/P1600

看博客:https://blog.csdn.net/clove_unique/article/details/53427248

思路好神啊...

树上差分是好东西。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int const maxn=300005;
int n,m,head[maxn],ct,w[maxn],tp,fp,ans[maxn],f[maxn][20],h[maxn],dfn[maxn],tim,a[maxn<<1];
struct P{int t,pt,val;}tor[maxn<<2],fr[maxn<<2];
struct N{
    int to,next;
    N(int t=0,int n=0):to(t),next(n) {}
}edge[maxn<<1];
void add(int x,int y){edge[++ct]=N(y,head[x]); head[x]=ct;}
bool cmp(P x,P y){return dfn[x.pt]<dfn[y.pt];}
void init(int x,int fa)
{
    h[x]=h[fa]+1; f[x][0]=fa; dfn[x]=++tim;
    for(int i=1;i<=18;i++)
        f[x][i]=f[f[x][i-1]][i-1];
    for(int i=head[x],u;i;i=edge[i].next)
        if((u=edge[i].to)!=fa)init(u,x);
}
int lca(int x,int y)
{
    if(h[x]<h[y])swap(x,y);
    int k=h[x]-h[y];
    for(int i=18;i>=0;i--)
        if(k&(1<<i))x=f[x][i];
    for(int i=18;i>=0;i--)
        if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
    if(x==y)return x;
    return f[x][0];
}
void dfs1(int x,int f)
{
    int val=w[x]+h[x],dec=a[val];
    while(tp<=(m<<1) && tor[tp].pt==x) a[tor[tp].t + h[x]]+=tor[tp].val, tp++;
    for(int i=head[x],u;i;i=edge[i].next)
        if((u=edge[i].to)!=f)dfs1(u,x);
    ans[x]+=a[val]-dec;
}
void dfs2(int x,int f)
{
    int val=w[x]-h[x]+1,dec=a[val];
    while(fp<=(m<<1) && fr[fp].pt==x) a[fr[fp].t]+=fr[fp].val, fp++;
    for(int i=head[x],u;i;i=edge[i].next)
        if((u=edge[i].to)!=f)dfs2(u,x);
    ans[x]+=a[val]-dec;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y); add(y,x);
    }
    init(1,0);
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    for(int i=1,s,t;i<=m;i++)
    {
        scanf("%d%d",&s,&t);
        int r=lca(s,t);
        tor[++tp].pt=s; tor[tp].t=0; tor[tp].val=1;
        if(r!=1)tor[++tp].pt=f[r][0], tor[tp].t=h[s]-h[r]+1, tor[tp].val=-1;
        fr[++fp].pt=t; fr[fp].t=h[s]-h[r]-h[r]+1; fr[fp].val=1;
        fr[++fp].pt=r; fr[fp].t=h[s]-h[r]-h[r]+1; fr[fp].val=-1;
    }
    
    sort(tor+1,tor+tp+1,cmp); tp=1;
    sort(fr+1,fr+fp+1,cmp); fp=1;
    dfs1(1,0);
    memset(a,0,sizeof a);
    dfs2(1,0);
    for(int i=1;i<=n;i++)printf("%d ",ans[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/Zinn/p/9216874.html