A Simple Problem On A Tree

题意:

给出一棵树,树上每个点有一个权值,有如下操作:
1.输入 (u,v,w),把点 (u,v) 之间路径上的点的权值全部赋值为 (w)
2.输入 (u,v,w),把点 (u,v) 之间路径上的点的权值全部加上 (w)
3.输入 (u,v,w),把点 (u,v) 之间路径上的点的权值全部乘以 (w)
4.输入 (u,v),求出 (sum_{x}{W_x^3}),即路径上所有点的立方和;
数据范围:$ 0 leq w leq 1,000,000,000,1≤N≤100,000$
传送门

分析:

树链剖分,线段树维护区间(平方和,立方和)修改区间(加,赋值,乘)。

维护立方和,要同时维护平方和与一次方和。
同时处理加法和乘法处理时,先处理乘法,同时注意乘法对加法的影响,如:((a+b)*c=a*c+b*c)
另外,赋值的处理先于乘法和加法。
维护平方和和立方和,根据:
((w+d)^3=w^3+(d^3+3*w^2*d+3*w*d^2),(w*d)^3=w^3*d^3)
((w+d)^2=w^2+(d^2+2*w*d),(w*d)^2=w^2*d^2)
另外立方和的处理先于平方和,平方和的处理先于一次方的处理。
注意取模。

代码:

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const ll mod=1e9+7;
const int N=1e5+5;
struct node
{
    ll val,sqa,cub;
}tree[N<<2];
struct tag
{
    ll add,mul,asg;
}lazy[N<<2];
vector<int>pic[N];
ll w[N];
int dfn[N],rnk[N],sz[N],son[N],depth[N],fa[N],top[N];
int n;
/**重链剖分**/
void dfs1(int v,int p,int d)
{
    sz[v]=1;
    son[v]=0;
    depth[v]=d;
    fa[v]=p;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p) continue;
        dfs1(u,v,d+1);
        sz[v]+=sz[u];
        if(sz[u]>sz[son[v]])
            son[v]=u;
    }
}
void dfs2(int v,int p,int tp,int &cnt)
{
    dfn[v]=++cnt;
    rnk[cnt]=v;
    top[v]=tp;
    if(!son[v]) return;
    dfs2(son[v],v,tp,cnt);
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||u==son[v]) continue;
        dfs2(u,v,u,cnt);
    }
}
/**线段树**/
void pushup(int rt)
{
    tree[rt].val=(tree[rt<<1].val+tree[rt<<1|1].val)%mod;
    tree[rt].sqa=(tree[rt<<1].sqa+tree[rt<<1|1].sqa)%mod;
    tree[rt].cub=(tree[rt<<1].cub+tree[rt<<1|1].cub)%mod;
}
void pushdown(int rt,int ln,int rn)
{//先判断是否是赋值:
    //先乘后加:
    if(lazy[rt].asg)//赋值要优先于加和乘处理
    {
        ll t=lazy[rt].asg%mod;
        tree[rt<<1].cub=t*t%mod*t%mod*ln%mod;
        tree[rt<<1|1].cub=t*t%mod*t%mod*rn%mod;
        tree[rt<<1].sqa=t*t%mod*ln%mod;
        tree[rt<<1|1].sqa=t*t%mod*rn%mod;
        tree[rt<<1].val=t%mod*ln%mod;
        tree[rt<<1|1].val=t%mod*rn%mod;
        lazy[rt<<1]=tag{0,1,t};
        lazy[rt<<1|1]=tag{0,1,t};
        lazy[rt].asg=0;//赋值处理后不能把lazy[rt]的加和乘的标记给取消
    }
    if(lazy[rt].mul>1)
    {
        ll t=lazy[rt].mul*lazy[rt].mul%mod*lazy[rt].mul%mod;
        tree[rt<<1].cub=tree[rt<<1].cub*t%mod;
        tree[rt<<1|1].cub=tree[rt<<1|1].cub*t%mod;

        t=lazy[rt].mul*lazy[rt].mul%mod;
        tree[rt<<1].sqa=tree[rt<<1].sqa*t%mod;
        tree[rt<<1|1].sqa=tree[rt<<1|1].sqa*t%mod;

        tree[rt<<1].val=tree[rt<<1].val*lazy[rt].mul%mod;
        tree[rt<<1|1].val=tree[rt<<1|1].val*lazy[rt].mul%mod;

        lazy[rt<<1].add=lazy[rt<<1].add*lazy[rt].mul%mod;
        lazy[rt<<1].mul=lazy[rt<<1].mul*lazy[rt].mul%mod;
        lazy[rt<<1|1].add=lazy[rt<<1|1].add*lazy[rt].mul%mod;
        lazy[rt<<1|1].mul=lazy[rt<<1|1].mul*lazy[rt].mul%mod;
        lazy[rt].mul=1;
    }
    if(lazy[rt].add>0)
    {//(w+d)^3=w^3+(d^3+3*w^2*d+3*w*d^2)
        ll t=lazy[rt].add*lazy[rt].add%mod*lazy[rt].add%mod;
        ll a=3LL*tree[rt<<1].sqa%mod*lazy[rt].add%mod;
        ll b=3LL*tree[rt<<1].val%mod*lazy[rt].add%mod*lazy[rt].add%mod;
        tree[rt<<1].cub=(tree[rt<<1].cub+t*ln%mod+a+b)%mod;

        a=3LL*tree[rt<<1|1].sqa%mod*lazy[rt].add%mod;
        b=3LL*tree[rt<<1|1].val%mod*lazy[rt].add%mod*lazy[rt].add%mod;
        tree[rt<<1|1].cub=(tree[rt<<1|1].cub+t*rn%mod+a+b)%mod;
        //(w+d)^2=w^2+(d^2+2*w*d)
        t=lazy[rt].add*lazy[rt].add%mod;
        a=2LL*tree[rt<<1].val%mod*lazy[rt].add%mod;
        tree[rt<<1].sqa=(tree[rt<<1].sqa+t*ln%mod+a)%mod;

        a=2LL*tree[rt<<1|1].val%mod*lazy[rt].add%mod;
        tree[rt<<1|1].sqa=(tree[rt<<1|1].sqa+t*rn%mod+a)%mod;
        //
        tree[rt<<1].val=(tree[rt<<1].val+lazy[rt].add*ln%mod)%mod;
        tree[rt<<1|1].val=(tree[rt<<1|1].val+lazy[rt].add*rn%mod)%mod;

        lazy[rt<<1].add=(lazy[rt<<1].add+lazy[rt].add)%mod;
        lazy[rt<<1|1].add=(lazy[rt<<1|1].add+lazy[rt].add)%mod;
        lazy[rt].add=0;
    }
}
void build(int l,int r,int rt)
{
    lazy[rt]=tag{0,1,0};
    if(l==r)
    {
        int t=rnk[l];
        tree[rt].val=w[t];
        tree[rt].sqa=w[t]*w[t]%mod;
        tree[rt].cub=w[t]*w[t]%mod*w[t]%mod;
        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    pushup(rt);
}
void update(int l,int r,int L,int R,int rt,ll num,int f)
{
    if(L<=l&&r<=R)
    {
        if(f==1)//直接赋值
        {
            tree[rt].val=num%mod*(r-l+1)%mod;
            tree[rt].sqa=num%mod*num%mod*(r-l+1)%mod;
            tree[rt].cub=num%mod*num%mod*num%mod*(r-l+1)%mod;
            lazy[rt]=tag{0,1,num};
        }
        else if(f==2)//加
        {//立方先加
            //立方:(w+d)^3=w^3+(d^3+3*w^2*d+3*w*d^2)
            ll t=(num%mod*num%mod*num%mod*(r-l+1)%mod+3LL*tree[rt].sqa%mod*num%mod+3LL*num*num%mod*tree[rt].val%mod)%mod;
            tree[rt].cub=(tree[rt].cub+t)%mod;
            //平方:w^2+d^2+2*w*d=(w+d)^2
            t=(num%mod*num%mod*(r-l+1)%mod+2LL*num%mod*tree[rt].val%mod)%mod;
            tree[rt].sqa=(tree[rt].sqa+t)%mod;
            //一次方:
            tree[rt].val=(tree[rt].val+num*(r-l+1)%mod)%mod;
            
            lazy[rt].add=(lazy[rt].add+num)%mod;
        }
        else//乘
        {//注意对加法的影响
            //立方:(w*d)^3=(w^3)*(d^3)
            tree[rt].cub=(tree[rt].cub*num%mod*num%mod*num)%mod;
            //平方:(w*d)^2=w^2*d^2
            tree[rt].sqa=(tree[rt].sqa*num%mod*num)%mod;
            //一次方:
            tree[rt].val=(tree[rt].val*num)%mod;
            lazy[rt].add=lazy[rt].add*num%mod;
            lazy[rt].mul=lazy[rt].mul*num%mod;
        }
        return;
    }
    int mid=(l+r)>>1;
    pushdown(rt,mid-l+1,r-mid);
    if(L<=mid)
        update(l,mid,L,R,rt<<1,num,f);
    if(R>mid)
        update(mid+1,r,L,R,rt<<1|1,num,f);
    pushup(rt);
}
ll query(int l,int r,int L,int R,int rt)
{
    if(L<=l&&r<=R)
        return tree[rt].cub%mod;
    int mid=(l+r)>>1;
    pushdown(rt,mid-l+1,r-mid);
    ll ans=0;
    if(L<=mid)
        ans=(ans+query(l,mid,L,R,rt<<1))%mod;
    if(R>mid)
        ans=(ans+query(mid+1,r,L,R,rt<<1|1))%mod;
    return ans;
}
void change(int u,int v,ll w,int f)
{
    while(top[u]!=top[v])
    {
        if(depth[top[v]]<depth[top[u]]) swap(u,v);
        update(1,n,dfn[top[v]],dfn[v],1,w,f);
        v=fa[top[v]];
    }
    if(depth[v]<depth[u]) swap(u,v);
    update(1,n,dfn[u],dfn[v],1,w,f);
}
ll ask(int u,int v)
{
    ll res=0;
    while(top[u]!=top[v])
    {
        if(depth[top[v]]<depth[top[u]]) swap(u,v);
        res=(res+query(1,n,dfn[top[v]],dfn[v],1))%mod;
        v=fa[top[v]];
    }
    if(depth[v]<depth[u]) swap(u,v);
    res=(res+query(1,n,dfn[u],dfn[v],1))%mod;
    return res;
}
/**初始化**/
void init()
{
    for(int i=1;i<=n;i++)
        pic[i].clear();
}
int main()
{
    int t,x,y,cas=0,q;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d",&n);
        init();
        for(int i=1;i<n;i++)
        {
            scanf("%d%d",&x,&y);
            pic[x].pb(y);
            pic[y].pb(x);
        }
        for(int i=1;i<=n;i++)
            scanf("%lld",&w[i]);
        int cnt=0;
        dfs1(1,0,0);
        dfs2(1,0,1,cnt);
        build(1,n,1);
        scanf("%d",&q);
        printf("Case #%d:
",++cas);
        while(q--)
        {
            int op,u,v;
            ll we;
            scanf("%d",&op);
            scanf("%d%d",&u,&v);
            if(op==4)
                printf("%lld
",ask(u,v));
            else
            {
                scanf("%lld",&we);
                change(u,v,we,op);
            }
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/1024-xzx/p/12859340.html