bzoj3901

题解:

就是按照常规的合并

期望有一点麻烦

首先计算全部的和

再减去有多少种

具体看看http://blog.csdn.net/PoPoQQQ/article/category/2542261这个博客吧

代码:

#include<bits/stdc++.h>
using namespace std;
#define pa t[x].fa
#define lc t[x].ch[0]
#define rc t[x].ch[1]
const int N=5e4+5;
typedef long long ll;
int read()
{
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if (c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}
struct node
{
    int ch[2],fa,rev;
    ll add,lsum,rsum,sum,exp,w,size;
}t[N];
int wh(int x){return t[pa].ch[1]==x;}
int isRoot(int x){return t[pa].ch[0]!=x&&t[pa].ch[1]!=x;}
void update(int x)
{
    t[x].size=t[lc].size+t[rc].size+1;
    t[x].sum=t[lc].sum+t[rc].sum+t[x].w;
    t[x].lsum=t[lc].lsum+t[x].w*(t[lc].size+1)+t[rc].lsum+t[rc].sum*(t[lc].size+1);
    t[x].rsum=t[rc].rsum+t[x].w*(t[rc].size+1)+t[lc].rsum+t[lc].sum*(t[rc].size+1);
    t[x].exp=t[lc].exp+t[rc].exp
    +t[lc].lsum*(t[rc].size+1)+t[rc].rsum*(t[lc].size+1)
    +t[x].w*(t[lc].size+1)*(t[rc].size+1);
}
ll cal1(ll x){return x*(x+1)/2;}
ll cal2(ll x){return x*(x+1)*(x+2)/6;}
void paint(int x,ll d)
{
    t[x].w+=d;
    t[x].add+=d;
    t[x].sum+=d*t[x].size;
    t[x].lsum+=d*cal1(t[x].size);
    t[x].rsum+=d*cal1(t[x].size);
    t[x].exp+=d*cal2(t[x].size);
}
void rever(int x)
{
    swap(lc,rc);
    swap(t[x].lsum,t[x].rsum);
    t[x].rev^=1;
}
void pushDown(int x)
{
    if (t[x].rev)
     {
        rever(lc);
        rever(rc);
        t[x].rev=0;
     }
    if (t[x].add)
     {
        paint(lc,t[x].add);
        paint(rc,t[x].add);
        t[x].add=0;
     }
}
void rotate(int x)
{
    int f=t[x].fa,g=t[f].fa,c=wh(x);
    if (!isRoot(f)) t[g].ch[wh(f)]=x;t[x].fa=g;
    t[f].ch[c]=t[x].ch[c^1];t[t[f].ch[c]].fa=f;
    t[x].ch[c^1]=f;t[f].fa=x;
    update(f);update(x);
}
int st[N],top;
void splay(int x)
{
    top=0;st[++top]=x;
    for (int i=x;!isRoot(i);i=t[i].fa) st[++top]=t[i].fa;
    for (int i=top;i>=1;i--) pushDown(st[i]);
    for (;!isRoot(x);rotate(x))
     if (!isRoot(pa)) rotate(wh(x)==wh(pa)?pa:x);
}
void Access(int x)
{
    for (int y=0;x;y=x,x=pa)
     {
        splay(x);
        rc=y;
        update(x);
     }
}
void MakeR(int x){Access(x);splay(x);rever(x);}
int FindR(int x){Access(x);splay(x);while(lc) x=lc;return x;}
void Link(int x,int y){MakeR(x);t[x].fa=y;}
void Cut(int x,int y)
{
    MakeR(x);Access(y);splay(y);
    t[y].ch[0]=t[x].fa=0;
    update(y);
}
void Add(int x,int y,int d)
{
    if (FindR(x)!=FindR(y)) return;
    MakeR(x);Access(y);splay(y);
    paint(y,d);
}
ll gcd(ll a,ll b){return b==0?a:gcd(b,a%b);}
void Que(int x,int y)
{
    if (FindR(x)!=FindR(y)){puts("-1");return;}
    MakeR(x);Access(y);splay(y);
    ll a=t[y].exp,b=t[y].size*(t[y].size+1)/2;
    ll g=gcd(a,b);
    printf("%lld/%lld
",a/g,b/g);
}
int n,Q,a,op,x,y,d;
int main()
{
    n=read();Q=read();
    for (int i=1;i<=n;i++)
     {
        a=read();
        t[i].size=1;
        t[i].w=t[i].lsum=t[i].rsum=t[i].sum=t[i].exp=a;
     }
    for (int i=1;i<=n-1;i++) x=read(),y=read(),Link(x,y);
    while(Q--)
     {
        op=read();x=read();y=read();
        if (op==1) if (FindR(x)==FindR(y)) Cut(x,y);
        if (op==2) if (FindR(x)!=FindR(y)) Link(x,y);
        if (op==3) d=read(),Add(x,y,d);
        if (op==4) Que(x,y);
     }
}
原文地址:https://www.cnblogs.com/xuanyiming/p/8024927.html