2019ICPC上海站 F-A Simple Problem On A Tree 线段树+树链剖分

2019ICPC上海站 F-A Simple Problem On A Tree

题意

给定一颗(n)个结点的树,每个点都有一个点权(a_i),有(Q)次询问,询问有四种:

  • (1~u~v~w),将(u)(v)的路径上的点的点权赋值为(w)
  • (2~u~v~w),将(u)(v)的路径上的点的点权加上(w)
  • (3~u~v~w),将(u)(v)的路径上的点的点权乘上(w)
  • (4~u~v),询问(u)(v)的路径上的点的点权立方和。

分析

((x+w)^3=x^3+3x^2w+3xw^2+w^3,(x+w)^2=x^2+2xw+w^2),这样展开一下,我们就发现只需要在线段树上维护区间立方和(x^3),区间平方和(x^2),区间和(x),对于修改操作需要三个lazy标记,分别记录赋值、加权和、乘积,下推标记的时候要注意优先级,对于赋值操作,要先将另两个标记初始化,再打上赋值标记,对于乘积操作,要把加权和标记也对应的乘上(w)。由于是树上操作,再套个树链剖分就好了。

Code

#include<bits/stdc++.h>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=1e5+10;
const int inf=1e9;
int T,Q,n;
vector<int>g[N];
int sz[N],f[N],d[N],top[N],son[N],p[N],id[N],tot;
ll a[N],sum[N<<2][3],tag[N<<2][3];
ll cal(ll x,int i){
    ll y=1;
    for(int j=0;j<i;j++) y=y*x%mod;
    return y;
}
void pp(int p){
    for(int i=0;i<3;i++) sum[p][i]=(sum[p<<1][i]+sum[p<<1|1][i])%mod;
}
void bd(int l,int r,int p){
    tag[p][0]=1,tag[p][1]=0,tag[p][2]=-1;
    if(l==r){
        for(int i=0;i<3;i++) sum[p][i]=cal(a[id[l]],i+1);
        return;
    }
    int mid=l+r>>1;
    bd(lson);bd(rson);
    pp(p);
}
void change(int l,int r,int p,ll k1,ll k2,ll k3){
    if(k3!=-1){
        for(int i=0;i<3;i++) sum[p][i]=(r-l+1)*cal(k3,i+1)%mod;
        tag[p][0]=1,tag[p][1]=0,tag[p][2]=k3;
    }
    for(int i=0;i<3;i++) sum[p][i]=sum[p][i]*cal(k1,i+1)%mod;
    sum[p][2]=(sum[p][2]+3*sum[p][1]*k2%mod+3*sum[p][0]*cal(k2,2)%mod+(r-l+1)*cal(k2,3)%mod)%mod;
    sum[p][1]=(sum[p][1]+2*sum[p][0]*k2%mod+(r-l+1)*cal(k2,2)%mod)%mod;
    sum[p][0]=(sum[p][0]+(r-l+1)*k2%mod)%mod;
    tag[p][0]=tag[p][0]*k1%mod;
    tag[p][1]=(tag[p][1]*k1%mod+k2%mod)%mod;
}
void up(int dl,int dr,int l,int r,int p,ll k1,ll k2,ll k3){
    if(l==dl&&r==dr){
        change(l,r,p,k1,k2,k3);
        return;
    }
    int mid=l+r>>1;
    if(tag[p][0]!=1||tag[p][1]!=0||tag[p][2]!=-1){
        change(lson,tag[p][0],tag[p][1],tag[p][2]);
        change(rson,tag[p][0],tag[p][1],tag[p][2]);
        tag[p][0]=1,tag[p][1]=0,tag[p][2]=-1;
    }
    if(dr<=mid) up(dl,dr,lson,k1,k2,k3);
    else if(dl>mid) up(dl,dr,rson,k1,k2,k3);
    else up(dl,mid,lson,k1,k2,k3),up(mid+1,dr,rson,k1,k2,k3);
    pp(p);
}
ll qy(int dl,int dr,int l,int r,int p){
    if(l==dl&&r==dr) return sum[p][2];
    int mid=l+r>>1;
    if(tag[p][0]!=1||tag[p][1]!=0||tag[p][2]!=-1){
        change(lson,tag[p][0],tag[p][1],tag[p][2]);
        change(rson,tag[p][0],tag[p][1],tag[p][2]);
        tag[p][0]=1,tag[p][1]=0,tag[p][2]=-1;
    }
    if(dr<=mid) return qy(dl,dr,lson);
    else if(dl>mid) return qy(dl,dr,rson);
    else return (qy(dl,mid,lson)+qy(mid+1,dr,rson))%mod;
}
void dfs(int u){
    sz[u]=1;d[u]=d[f[u]]+1;
    for(int x:g[u]){
        if(x==f[u]) continue;
        f[x]=u;
        dfs(x);
        sz[u]+=sz[x];
        if(sz[x]>sz[son[u]]) son[u]=x;
    }
}
void dfs1(int u,int t){
    top[u]=t;p[u]=++tot;id[tot]=u;
    if(son[u]) dfs1(son[u],t);
    for(int x:g[u]){
        if(x==son[u]||x==f[u]) continue;
        dfs1(x,x);
    }
}
void modify(int x,int y,ll k1,ll k2,ll k3){
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]]) swap(x,y);
        up(p[top[x]],p[x],1,n,1,k1,k2,k3);
        x=f[top[x]];
    }
    if(d[x]<d[y]) swap(x,y);
    up(p[y],p[x],1,n,1,k1,k2,k3);
}
ll solve(int x,int y){
    ll ret=0;
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]]) swap(x,y);
        ret=(ret+qy(p[top[x]],p[x],1,n,1))%mod;
        x=f[top[x]];
    }
    if(d[x]<d[y]) swap(x,y);
    ret=(ret+qy(p[y],p[x],1,n,1))%mod;
    return ret;
}
int main(){
    scanf("%d",&T);
    for(int cas=1;cas<=T;cas++){
        tot=0;
        scanf("%d",&n);
        for(int i=2,x,y;i<=n;i++){
            scanf("%d%d",&x,&y);
            g[x].pb(y);g[y].pb(x);
        }
        for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
        dfs(1);dfs1(1,1);
        bd(1,n,1);
        scanf("%d",&Q);
        printf("Case #%d:
",cas);
        while(Q--){
            int op,u,v,w;
            scanf("%d%d%d",&op,&u,&v);
            if(op==1){
                scanf("%d",&w);
                modify(u,v,1,0,w);
            }else if(op==2){
                scanf("%d",&w);
                modify(u,v,1,w,-1);
            }else if(op==3){
                scanf("%d",&w);
                modify(u,v,w,0,-1);
            }else{
                printf("%lld
",solve(u,v));
            }
        }
        for(int i=1;i<=n;i++){
            g[i].clear();
            son[i]=0;
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/xyq0220/p/13967190.html