NOIP 模拟 $90; m 校门外歪脖树上的鸽子$

题解 (by;zjvarphi)

树上问题,采用树链剖分。

考虑一下线段树区间查询的过程,放到这题上就是:递归下去的从儿子的区间交点分开。

如果递归右儿子,那么就会对左儿子造成贡献,反之同理。

具体实现就是开两棵线段树,一棵表示从右链递归,对左儿子的贡献,另一棵相反。

求出每个点从下往上走左子树到的深度最浅的点,和走右子树到的。

注意根节点的兄弟是它自己。

Code
#include<bits/stdc++.h>
#define ri signed
#define pd(i) ++i
#define bq(i) --i
#define func(x) std::function<x>
namespace IO{
    char buf[1<<21],*p1=buf,*p2=buf;
    #define gc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?(-1):*p1++
    #define debug1(x) std::cerr << #x"=" << x << ' '
    #define debug2(x) std::cerr << #x"=" << x << std::endl
    #define Debug(x) assert(x)
    struct nanfeng_stream{
        template<typename T>inline nanfeng_stream &operator>>(T &x) {
            bool f=false;x=0;char ch=gc();
            while(!isdigit(ch)) f|=ch=='-',ch=gc();
            while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=gc();
            return x=f?-x:x,*this;
        } 
    }cin;
}
using IO::cin;
namespace nanfeng{
    #define mk std::make_pair
    #define FI FILE *IN
    #define FO FILE *OUT
    template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
    template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
    using ull=long long;
    static const int N=4e5+7;
    int ch[N][2],ws[N],top[N],siz[N],hs[N],le[N],ll[N],rl[N],ul[N],ur[N],br[N],fa[N],dfn[N],bc[N],dep[N],tot,X,Y,n,m,opt,l,r,w,rt,typ,al;
    std::map<std::pair<int,int>,int> mp;
    func(void(int)) dfs1=[](int x) {
        siz[x]=le[x]=1;
        int sl=ch[x][0],sr=ch[x][1];
        if (!fa[x]) ul[x]=ur[x]=x;
        else ul[x]=!ws[x]?ul[fa[x]]:x,ur[x]=ws[x]?ur[fa[x]]:x;
        ll[x]=rl[x]=x;
        dep[x]=dep[fa[x]]+1;
        if (x>n) {
            dfs1(sl),dfs1(sr);
            siz[x]+=siz[sl]+siz[sr]; 
            hs[x]=siz[sl]>siz[sr]?sl:sr;
            ll[x]=ll[sl],rl[x]=rl[sr];
            le[x]=le[sl]+le[sr];
        }
        mp[mk(ll[x],rl[x])]=x;
    };
    func(void(int,int)) dfs2=[](int x,int tp) {
        dfn[bc[tot]=x]=++tot;
        top[x]=tp;
        if (hs[x]) dfs2(hs[x],tp);
        if (x>n) {
            int v=ch[x][0]==hs[x]?ch[x][1]:ch[x][0];
            dfs2(v,v);
        }
    };
    struct Segmenttree{
        #define ls(x) (x<<1)
        #define rs(x) (x<<1|1)
        #define up(x) T[x].sum=T[ls(x)].sum+T[rs(x)].sum
        struct segmenttree{ull sum,le,lz;}T[N<<2];
        func(void(int,int,int)) build=[&](int x,int l,int r) {
            if (l==r) {
                int k=bc[l];
                if (ws[k]!=typ) T[x].le=le[br[k]];
                return;
            }
            int mid=(l+r)>>1;
            build(ls(x),l,mid);
            build(rs(x),mid+1,r);
            T[x].le=T[ls(x)].le+T[rs(x)].le;
        };
        func(void(int)) down=[&](int x) {
            if (!T[x].lz) return;
            T[ls(x)].lz+=T[x].lz,T[ls(x)].sum+=T[ls(x)].le*T[x].lz;
            T[rs(x)].lz+=T[x].lz,T[rs(x)].sum+=T[rs(x)].le*T[x].lz;
            T[x].lz=0;
        };
        func(void(int,int,int,int,int,int)) update=[&](int x,int k,int l,int r,int lt,int rt) {
            if (l<=lt&&rt<=r) return T[x].lz+=k,T[x].sum+=T[x].le*k,void();
            int mid=(lt+rt)>>1;
            down(x);
            if (l<=mid) update(ls(x),k,l,r,lt,mid);
            if (r>mid) update(rs(x),k,l,r,mid+1,rt);
            up(x);
        };
        func(ull(int,int,int,int,int)) query=[&](int x,int l,int r,int lt,int rt) {
            if (l<=lt&&rt<=r) return T[x].sum;
            int mid=(lt+rt)>>1;
            ull res=0;
            down(x);
            if (l<=mid) res+=query(ls(x),l,r,lt,mid);
            if (r>mid) res+=query(rs(x),l,r,mid+1,rt);
            return res;
        };
    }T[2];
    auto Aupdate=[](int x,int w) {
        if (ws[x]) T[0].update(1,w,dfn[br[x]],dfn[br[x]],1,al);
        else T[1].update(1,w,dfn[br[x]],dfn[br[x]],1,al);
    };
    auto Aquery=[](int x) {
        if (ws[x]) return T[0].query(1,dfn[br[x]],dfn[br[x]],1,al);
        else return T[1].query(1,dfn[br[x]],dfn[br[x]],1,al);
    };
    auto update=[](const int opt,int x,int v,int w) {
        while(top[x]!=top[v]) {
            T[opt].update(1,w,dfn[top[x]],dfn[x],1,al);
            x=fa[top[x]];
        }
        if (x!=v) T[opt].update(1,w,dfn[v]+1,dfn[x],1,al);
    };
    auto query=[](const int opt,int x,int v) {
        ull res=0;
        while(top[x]!=top[v]) {
            res+=T[opt].query(1,dfn[top[x]],dfn[x],1,al);
            x=fa[top[x]];
        }
        if (x!=v) res+=T[opt].query(1,dfn[v]+1,dfn[x],1,al);
        return res;
    };
    auto Getlca=[](int x,int v) {
        while(top[x]!=top[v]) {
            if (dep[top[x]]<dep[top[v]]) std::swap(x,v);
            x=fa[top[x]];
        }
        return dep[x]<dep[v]?x:v;
    };
    auto find=[](int x,int v) {
        x=top[x];
        while(x!=top[v]) {
            if (fa[x]==v) return x;
            x=top[fa[x]];
        }
        return hs[v];
    };
    inline int main() {
        FI=freopen("pigeons.in","r",stdin);
        FO=freopen("pigeons.out","w",stdout);
        cin >> n >> m;
        al=(n<<1)-1;
        memset(ws,-1,sizeof(ws));
        for (ri i(1);i<n;pd(i)) {
            cin >> X >> Y;
            ch[n+i][0]=X,ch[n+i][1]=Y;
            ws[X]=0,ws[Y]=1;
            fa[X]=fa[Y]=n+i;
            br[X]=Y,br[Y]=X;
        }
        for (ri i(1);i<=(n<<1)-1;pd(i)) if (!fa[i]) {rt=i;break;}
        br[rt]=rt;
        dfs1(rt),dfs2(rt,rt);
        typ=1,T[0].build(1,1,al);
        typ=0,T[1].build(1,1,al);
        for (ri i(1);i<=m;pd(i)) {
            cin >> opt >> l >> r;
            int k=-1;
            std::pair<int,int> tmp=mk(l,r);
            if (mp.find(tmp)!=mp.end()) k=mp[tmp];
            if (opt==1) {
                cin >> w;
                if (k!=-1) Aupdate(k,w);
                else {
                    int nx=ul[l],ny=ur[r],lca=Getlca(l,r);
                    if (dep[nx]<=dep[lca]) Aupdate(find(l,lca),w);
                    else Aupdate(nx,w),update(0,nx,find(l,lca),w);
                    if (dep[ny]<=dep[lca]) Aupdate(find(r,lca),w);
                    else Aupdate(ny,w),update(1,ny,find(r,lca),w);
                }
            } else {
                ull res=0;
                if (k!=-1) res=Aquery(k);
                else {
                    int nx=ul[l],ny=ur[r],lca=Getlca(l,r);
                    // debug1(nx),debug1(ny),debug2(lca);
                    if (dep[nx]<=dep[lca]) res+=Aquery(find(l,lca));
                    else res+=Aquery(nx)+query(0,nx,find(l,lca));
                    if (dep[ny]<=dep[lca]) res+=Aquery(find(r,lca));
                    else res+=Aquery(ny)+query(1,ny,find(r,lca));
                }
                printf("%llu
",res);
            }
        }
        return 0;
    }
}
int main() {return nanfeng::main();}
原文地址:https://www.cnblogs.com/nanfeng-blog/p/15515305.html