洛谷 P1501 [国家集训队] Tree II — Link-Cut Tree

Code:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
using namespace std;

// Local-testing helper: redirect stdin from "<a>.in" and stdout to "<a>.out".
// Return values of freopen are not checked.
void setIO(std::string a)
{
    const std::string inPath = a + ".in";
    const std::string outPath = a + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}

const int maxn = 100009;   // capacity of every per-node array (vertex ids index them directly)
typedef long long ll;
const int mod = 51061;     // all values and sums are kept modulo this

// Splay-forest structure (index 0 acts as the null node):
//   f    - parent in the splay tree, or path-parent when the node is a splay root
//   ch   - left/right children; siz - splay-subtree size
//   tag  - pending subtree reversal; sta - scratch stack used by splay
int f[maxn], ch[maxn][2], siz[maxn], tag[maxn], sta[maxn], n, m;
// Per-node arithmetic data:
//   mult/add - pending lazy transform "v*mult + add" for the children
//   sumv     - sum of val over the node's splay subtree; val - the node's own value
ll mult[maxn], add[maxn], sumv[maxn], val[maxn];

// Left child of x in its splay tree (0 when there is none).
int lson(int x) { return ch[x][0]; }
// Right child of x in its splay tree (0 when there is none).
int rson(int x) { return ch[x][1]; }
// Side of its parent that x hangs on: 1 for the right child, 0 otherwise.
int get(int x)
{
    const int parent = f[x];
    return ch[parent][1] == x ? 1 : 0;
}
// Nonzero iff x is the root of its splay tree: x is neither child of f[x]
// (so f[x] is a path-parent pointer or 0, not a real splay parent).
int isRoot(int x)
{
    const int parent = f[x];
    return ch[parent][0] != x && ch[parent][1] != x;
}
// Reverse x's splay subtree lazily: swap x's children now and toggle the
// pending-reversal flag for its descendants. Ignores the null node 0.
void mark(int x)
{
    if (x == 0) return;
    const int oldLeft = ch[x][0];
    ch[x][0] = ch[x][1];
    ch[x][1] = oldLeft;
    tag[x] ^= 1;
}

void pushdown(int x){
   // if(!x)return;
    if(mult[x]!=1) 
    {
        if(lson(x)) 
        {
            sumv[lson(x)]*=mult[x];
            mult[lson(x)]*=mult[x];
            add[lson(x)]*=mult[x];
            val[lson(x)]*=mult[x];

            add[lson(x)]%=mod;
            mult[lson(x)]%=mod;
            sumv[lson(x)]%=mod;
            val[lson(x)]%=mod;
        }
        if(rson(x)) 
        {
            sumv[rson(x)]*=mult[x];
            mult[rson(x)]*=mult[x];
            add[rson(x)]*=mult[x];
            val[rson(x)]*=mult[x];

            add[rson(x)]%=mod;
            mult[rson(x)]%=mod;
            sumv[rson(x)]%=mod;
            val[rson(x)]%=mod;
        }
        mult[x]=1;
    }
    if(add[x]) 
    {
        if(lson(x)) 
        {
            sumv[lson(x)]+=add[x]*siz[lson(x)];
            add[lson(x)]+=add[x];
            val[lson(x)]+=add[x];

            add[lson(x)]%=mod;
            sumv[lson(x)]%=mod;
            val[lson(x)]%=mod;
        }
        if(rson(x)) 
        {
            sumv[rson(x)]+=add[x]*siz[rson(x)];
            add[rson(x)]+=add[x];
            val[rson(x)]+=add[x];

            add[rson(x)]%=mod;
            sumv[rson(x)]%=mod;
            val[rson(x)]%=mod;
        }
        add[x]=0;
    }
    if(tag[x]) mark(ch[x][0]), mark(ch[x][1]), tag[x]=0;
}

void pushup(int x)
{
    if(!x)return;
    siz[x]=siz[lson(x)]+siz[rson(x)]+1;
    sumv[x]=(sumv[lson(x)]+sumv[rson(x)]+val[x])%mod;
}

// Rotate o one level up, above its splay parent.
// Precondition (established by splay): tags on o and its splay ancestors are pushed down.
void rotate(int o){
    const int old = f[o], fold = f[old], which = get(o);
    const int moved = ch[o][which ^ 1];               // o's inner subtree switches to old
    if (!isRoot(old)) ch[fold][ch[fold][1] == old] = o; // real child link; else keep path-parent
    f[o] = fold;
    ch[old][which] = moved;
    if (moved) f[moved] = old;                          // never write through the null node 0
    ch[o][which ^ 1] = old;
    f[old] = o;
    // Only old's and o's subtrees change. fold's aggregate is unchanged, and when
    // old was the splay root fold is a path-parent whose tags may not be pushed
    // down, so recomputing it here would read stale children.
    pushup(old);
    pushup(o);
}
// Splay x to the root of its splay tree.
// First pushes down pending tags along the splay-root-to-x path (top-down),
// then rotates x up until its parent is the tree's original path-parent.
void splay(int x){
    int top = 0;
    int cur = x;
    sta[++top] = cur;
    while (!isRoot(cur)) {
        cur = f[cur];
        sta[++top] = cur;
    }
    for (; top > 0; --top) pushdown(sta[top]);
    const int stop = f[cur];   // path-parent of this splay tree (may be 0)
    while (f[x] != stop) {
        const int fa = f[x];
        if (f[fa] != stop) rotate(get(x) == get(fa) ? fa : x);   // zig-zig vs zig-zag
        rotate(x);
    }
}
// Make the path from the represented tree's root down to x preferred:
// walk up via path-parents, splicing the previously built splay tree in as the right child.
void Access(int x)
{
    int prev = 0;
    while (x) {
        splay(x);
        ch[x][1] = prev;
        pushup(x);
        prev = x;
        x = f[x];
    }
}
// Re-root x's represented tree at x: expose the path, splay x, reverse it lazily.
void makeRoot(int x)
{
    Access(x);
    splay(x);
    mark(x);
}
// Root the tree at x, then expose the x..y path; afterwards y is that splay tree's root.
void split(int x,int y)
{
    makeRoot(x);
    Access(y);
    splay(y);
}
// Connect x and y: make x its tree's root, then hang it below y via a path-parent pointer.
// NOTE(review): does not check that x and y are in different trees.
void Link(int x,int y)
{
    makeRoot(x);
    f[x] = y;
}
void cut(int x,int y)
{  
    makeRoot(x), Access(y), splay(y);
    f[x] = ch[y][0] = 0, pushup(y); 
}

void debug(){ for(int i=1;i<=n;++i) printf("%d %lld
",siz[i],sumv[i]); }

int main(){
    //setIO("input");
    memset(add,0,sizeof(add)), memset(mult,1,sizeof(mult));
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i) sumv[i]=val[i]=1;
    for(int i=1;i<n;++i)
    { 
        int a,b;  
        scanf("%d%d",&a,&b);
        Link(a,b); 
    }

    for(int i=1;i<=m;++i)
    {
        char opt[10];
        int a,b,c,d;
        scanf("%s",opt);
        if(opt[0]=='+') 
        {
            scanf("%d%d%d",&a,&b,&c);
            split(a,b);
            sumv[b]+=siz[b]*c;
            add[b]+=c;
            val[b]+=c;

            sumv[b]%=mod;
            add[b]%=mod;
            val[b]%=mod;
        }
        if(opt[0]=='-')
        {
            scanf("%d%d%d%d",&a,&b,&c,&d);
            cut(a,b);
            Link(c,d);
        }
        if(opt[0]=='*')
        {
            scanf("%d%d%d",&a,&b,&c);
            split(a,b);
            sumv[b]*=c;
            add[b]*=c;
            mult[b]*=c;
            val[b]*=c;

            sumv[b]%=mod;
            add[b]%=mod;
            mult[b]%=mod;
            val[b]%=mod;
        }
        if(opt[0]=='/')
        {
            scanf("%d%d",&a,&b);
            split(a,b);
            printf("%lld
",sumv[b]%mod);
        }
    }
    return 0;
}

  

原文地址:https://www.cnblogs.com/guangheli/p/10028065.html