天天爱跑步 线段树合并

天天爱跑步 线段树合并

使用线段树合并做法

有路径(u_i->lca(u_i,v_i)->v_i),将路径分为两半分开讨论。

先考虑(u_i->lca(u_i,v_i))前一半路径:

对节点(x)有贡献,当且仅当(dep[u_i]-dep[x]=w[x]),移项使含(x)的在一边使统计更方便(dep[u_i]=w[x]+dep[x]),这样只需统计节点(x)子树内路径起点深度为(w[x]+dep[x])的个数。

(lca->v_i)后一半路径同理:

对节点(x)有贡献,当且仅当(dep[v_i]-dep[x]=dis-w[x]),其中(dis)(u,v)路径长度,同样的,为统计方便,运用参变分离思想,我们移项得(dep[v_i]-dis=dep[x]-w[x]),然后又可以把(dis)拆开,得到(-dep[u_i]+2*dep[lca]=dep[x]-w[x])这样只需要统计节点(x)子树内那些路径终点(v_i)满足(dep[u_i]-2*dep[lca])(w[x]-dep[x])的个数即可。

然后考虑如何统计。暴力枚举子树跟统计子树答案肯定炸,可以用桶或者线段树合并或者DSU on Tree统计。

这里使用线段树合并,因为上瘾了可以直接套上树上差分的板子不需要各种讨论所以思维难度很小。

两条路径各开一组权值线段树,下标为(dep),再运用差分思想。在(u->lca)路径上的线段树上,在(u)处+1,再在(lca)处-1;在(lca->v)路径上,在(v)处+1,(fa[lca])处-1,这样就保证了两条路径合并的信息。

然后节点(x)的答案即为两组线段树查询结果之和。

#include <cstdio>
#include <algorithm>
#define MAXN 300003
using namespace std;
inline int read(){
    char ch=getchar();int s=0;bool w=0;
    while((ch<'0'||ch>'9')&&(ch!='-')) ch=getchar();
    if(ch=='-'){w=1;ch=getchar();}
    while(ch>='0'&&ch<='9') s=s*10+(ch^'0'),ch=getchar();
    if(w) return -s;
    return s;
}
int head[MAXN],nxt[MAXN*2],vv[MAXN*2],tot;
inline void add_egde(int u, int v){
    vv[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
}
int dep[MAXN],mxs[MAXN],sz[MAXN],f[MAXN];
void dfs(int u, int fa){
    sz[u]=1;
    dep[u]=dep[fa]+1;
    f[u]=fa;
    int mxsz=-1;
    for(int i=head[u];i;i=nxt[i]){
        int v=vv[i];
        if(v==fa) continue;
        dfs(v, u);
        sz[u]+=sz[v];
        if(sz[v]>mxsz){
            mxsz=sz[v];
            mxs[u]=v;
        }
    }
}
int topf[MAXN];
void dfs2(int u, int top){
    topf[u]=top;
    if(mxs[u]==0) return;
    dfs2(mxs[u], top);
    for(int i=head[u];i;i=nxt[i]){
        int v=vv[i];
        if(v==f[u]||v==mxs[u]) continue;
        dfs2(v, v);
    }
}
int lca(int a,int b){
    while(topf[a]!=topf[b]){
        if(dep[topf[a]]<dep[topf[b]]) swap(a,b);
        a=f[topf[a]];
    }
    if(dep[a]<dep[b]) return a;
    else return b;
}
int cnt;
#define MAXM MAXN*2*18*2
int tre[MAXM],sl[MAXM],sr[MAXM];
void change(int &x, int l, int r, int pos, int val){
    if(x==0) x=++cnt;
    if(l==r){
        tre[x]+=val;
        return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid) change(sl[x], l, mid, pos, val);
    else change(sr[x], mid+1, r, pos, val);
}
int query(int x, int l, int r, int pos){
    if(x==0) return 0;
    if(l==r) return tre[x];
    int mid=(l+r)>>1;
    if(pos<=mid) return query(sl[x], l, mid, pos);
    else return query(sr[x], mid+1, r, pos);
}
int merge(int a, int b, int l, int r){
    if(a==0||b==0) return a+b;
    if(l==r){
        tre[a]+=tre[b];
        return a;
    }
    int mid=(l+r)>>1;
    sl[a]=merge(sl[a], sl[b], l, mid);
    sr[a]=merge(sr[a], sr[b], mid+1, r);
    return a;
}
int n,m,w[MAXN],rot1[MAXN],rot2[MAXN];
int ans[MAXN];
void solve(int u, int fa){
    for(int i=head[u];i;i=nxt[i]){
        int v=vv[i];
        if(v==fa) continue;
        solve(v, u);
        rot1[u]=merge(rot1[u], rot1[v], 1, n*2);
        rot2[u]=merge(rot2[u], rot2[v], 1, n*2);
    }
    ans[u]=query(rot1[u], 1, n*2, dep[u]+w[u])+query(rot2[u], 1, n*2, n+w[u]-dep[u]);
}
int main(){
    n=read(),m=read();
    for(int i=1;i<n;++i){
        int u=read(),v=read();
        add_egde(u, v);
    }
    dfs(1, 0);
    dfs2(1, 1);
    for(int i=1;i<=n;++i) w[i]=read();
    for(int i=1;i<=m;++i){
        int s=read(),t=read();
        int tmp=lca(s, t);
        change(rot1[s], 1, n*2, dep[s], 1);
        change(rot1[tmp], 1, n*2, dep[s], -1);
        change(rot2[t], 1, n*2, n+dep[s]-dep[tmp]*2, 1);
        change(rot2[f[tmp]], 1, n*2, n+dep[s]-dep[tmp]*2, -1);
    }
    solve(1, 1);
    for(int i=1;i<=n;++i) printf("%d ", ans[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/santiego/p/11770332.html